2025国赛数学建模C题详细思路模型代码获取,备战国赛算法解析——决策树
2025国赛数学建模C题详细思路模型代码获取见文末名片
决策树算法:从原理到实战(数模小白友好版)
1. 决策树是什么?——用生活例子理解核心概念
想象你周末想决定是否去野餐,可能会这样思考:
- 根节点(起点):是否去野餐?
- 内部节点(判断条件):
先看天气:晴天→继续判断;下雨→不去野餐(叶子节点)。
晴天再看温度:>30℃→不去;≤30℃→去野餐(叶子节点)。
这个“判断流程”就是一棵简单的决策树!决策树本质是通过一系列“ifelse”规则,将复杂问题拆解为多个简单子问题,最终输出预测结果。
2. 决策树核心:如何“问问题”?——分裂准则详解
决策树的关键是选择最优特征作为当前“判断条件”(即分裂节点)。不同算法的差异在于“如何定义最优”,这就是分裂准则。
2.1 分类决策树:让结果“更纯”
分类任务(如“是否违约”“是否患病”)的目标是让分裂后的子节点样本尽可能属于同一类别(即“纯度”最大化)。
2.1.1 ID3算法:用“信息增益”找最有用的特征
ID3算法用信息熵衡量“混乱程度”,用信息增益衡量特征对“减少混乱”的贡献。
第一步:理解信息熵(Entropy)——“混乱度”的量化
信息熵描述样本集的不确定性:熵越小,样本越纯(混乱度越低)。
公式:设样本集 ( D ) 有 ( K ) 类,第 ( k ) 类占比 ( p_k = \\frac{\\text{该类样本数}}{\\text{总样本数}} ),则:
[
H(D) = \\sum_{k=1}^K p_k \\log_2 p_k \\quad (\\text{单位:比特})
]
极端例子:
若所有样本都是同一类(纯节点),如“全是晴天”,则 ( p_1=1,p_2=…=p_K=0 ),( H(D)=0 )(完全确定,熵最小);
若样本均匀分布(最混乱),如二分类中“晴天/雨天各占50%”,则 ( H(D) = 0.5\\log_2 0.5 0.5\\log_2 0.5 = 1 )(熵最大)。
第二步:条件熵(Conditional Entropy)——“已知特征A时的混乱度”
假设用特征 ( A )(如“天气”,取值:晴天/阴天/雨天)分裂样本集 ( D ),会得到多个子集(如“晴天子集”“阴天子集”)。条件熵是这些子集熵的加权平均,衡量“已知特征A后,样本集的剩余混乱度”。
公式:特征 ( A ) 有 ( V ) 个取值,第 ( v ) 个子集 ( D_v ) 的样本数占比 ( \\frac{|D_v|}{|D|} ),则:
[
H(D|A) = \\sum_{v=1}^V \\frac{|D_v|}{|D|} H(D_v)
]
其中 ( H(D_v) ) 是子集 ( D_v ) 的信息熵。
第三步:信息增益(IG)——“特征A减少的混乱度”
信息增益 = 分裂前的熵 分裂后的条件熵,即:
[
\\text{IG}(A) = H(D) H(D|A)
]
IG越大,说明特征A减少的混乱度越多,越适合作为当前分裂特征。
举个例子:用“天气”特征分裂“是否去野餐”样本集
分裂前总熵 ( H(D) = 0.9 )(假设样本有一定混乱度);
分裂后条件熵 ( H(D|天气) = 0.3 )(每个天气子集的熵很小,因为晴天几乎都去,雨天几乎都不去);
信息增益 ( \\text{IG}(天气) = 0.9 0.3 = 0.6 )。
若“温度”特征的IG=0.4,则“天气”比“温度”更适合作为分裂特征。
2.1.2 C4.5算法:修正ID3的“偏爱多取值特征”缺陷
ID3有个致命问题:倾向选择取值多的特征(如“身份证号”每个样本取值不同)。
例如“身份证号”分裂后,每个子集只有1个样本(熵=0),条件熵 ( H(D|身份证号)=0 ),信息增益 ( \\text{IG}=H(D)0=H(D) ),远大于其他特征。但“身份证号”显然无预测意义!
C4.5的改进:用信息增益比(Gain Ratio) 替代信息增益,公式:
[
\\text{GainRatio}(A) = \\frac{\\text{IG}(A)}{H_A(D)}
]
其中 ( H_A(D) = \\sum_{v=1}^V \\frac{|D_v|}{|D|} \\log_2 \\frac{|D_v|}{|D|} ) 是特征 ( A ) 自身的熵(取值越多,( H_A(D) ) 越大)。
效果:取值多的特征(如身份证号)( H_A(D) ) 很大,导致增益比被“惩罚”(变小),从而避免被误选。
2.1.3 CART算法:用“基尼指数”更高效地衡量纯度
CART(分类回归树)是最常用的决策树算法,支持分类和回归,且是二叉树(每个节点只分2个子节点)。分类任务中,CART用基尼指数衡量纯度,计算更简单(无需对数运算)。
基尼指数(Gini Index)——“随机抽两个样本,类别不同的概率”
公式:样本集 ( D ) 的基尼指数:
[
\\text{Gini}(D) = 1 \\sum_{k=1}^K p_k^2
]
(( p_k ) 是第 ( k ) 类样本占比)
物理意义:随机从 ( D ) 中抽2个样本,它们类别不同的概率。纯度越高,该概率越小,基尼指数越小。
极端例子:
纯节点(全是同一类):( p_1=1 ),( \\text{Gini}(D)=11^2=0 );
二分类均匀分布(50%/50%):( \\text{Gini}(D)=1(0.52+0.52)=0.5 )(最大混乱)。
分裂后的基尼指数
若用特征 ( A ) 的阈值 ( t ) 将 ( D ) 分为左子树 ( D_1 ) 和右子树 ( D_2 ),则分裂后的基尼指数为:
[
\\text{Gini}(D|A,t) = \\frac{|D_1|}{|D|}\\text{Gini}(D_1) + \\frac{|D_2|}{|D|}\\text{Gini}(D_2)
]
CART分类树选择最小基尼指数的(特征,阈值)对作为分裂点。
2.2 回归决策树:让预测“更准”
回归任务(如“房价预测”“温度预测”)的目标是预测连续值,分裂准则是最小化平方误差(MSE)。
平方误差(MSE)——“预测值与真实值的平均差距”
假设用特征 ( A ) 的阈值 ( t ) 将样本集 ( D ) 分为 ( D_1 ) 和 ( D_2 ),叶子节点的预测值为子集的均值(因为均值能最小化平方误差):
[
c_1 = \\frac{1}{|D_1|}\\sum_{(x_i,y_i)\\in D_1} y_i, \\quad c_2 = \\frac{1}{|D_2|}\\sum_{(x_i,y_i)\\in D_2} y_i
]
平方误差为:
[
\\text{MSE}(A,t) = \\sum_{(x_i,y_i)\\in D_1} (y_i c_1)^2 + \\sum_{(x_i,y_i)\\in D_2} (y_i c_2)^2
]
CART回归树选择最小化MSE的(特征,阈值)对作为分裂点。
3. 手把手教你构建决策树(CART算法为例)
以CART分类树为例,完整步骤如下:
步骤1:准备数据
训练集:( D = {(x_1,y_1),…,(x_m,y_m)} )(( x_i ) 是特征向量,( y_i ) 是类别标签);
超参数:最小节点样本数 ( N_{\\text{min}} )(如5)、最小分裂增益 ( \\epsilon )(如0.01)。
步骤2:递归分裂节点(核心!)
对当前节点的样本集 ( D ),重复以下操作:
2.1 先判断是否停止分裂(终止条件)
若满足以下任一条件,当前节点成为叶子节点(输出类别/均值):
纯度足够高:所有样本属于同一类(分类)或MSE < ( \\epsilon )(回归);
没特征可分:特征集为空或所有样本特征值相同;
样本太少:节点样本数 < ( N_{\\text{min}} )(避免过拟合)。
2.2 若需分裂,选最优特征和阈值
遍历所有特征 ( A_j ) 和可能的分裂阈值 ( t ),计算分裂后的基尼指数(分类)或MSE(回归),选择最优分裂点。
离散特征:如“天气=晴/阴/雨”,尝试每个取值作为阈值(如“晴” vs “阴+雨”);
连续特征:如“温度”,排序后取相邻样本的中值作为候选阈值(如温度排序后为[15,20,25],候选阈值为17.5、22.5)。
2.3 分裂节点并递归
按最优(特征,阈值)将 ( D ) 分为左子树(满足条件,如“温度≤22.5”)和右子树(不满足条件),对左右子树重复步骤2.1~2.3。
步骤3:剪枝——解决“过拟合”问题
决策树容易“想太多”(过拟合):训练时把噪声也当成规律,导致对新数据预测不准。剪枝就是“简化树结构”,保留关键规律。
3.1 预剪枝(简单粗暴)
分裂过程中提前停止:
限制树深度(如最多5层);
节点样本数 < ( N_{\\text{min}} ) 时停止分裂;
分裂增益(如基尼指数下降量)< ( \\epsilon ) 时停止分裂。
3.2 后剪枝(更精细,推荐!)
先生成完整树,再“剪掉”冗余分支(以CART的代价复杂度剪枝为例):
-
定义代价函数:
[
C_\\alpha(T) = C(T) + \\alpha |T|
]
( C(T) ):训练误差(分类:基尼指数总和;回归:MSE总和);
( |T| ):叶子节点数;
( \\alpha \\geq 0 ):正则化参数(控制剪枝强度,( \\alpha ) 越大,树越简单)。 -
找最优剪枝节点:
对每个非叶子节点,计算“剪枝前后的代价差”:
[
\\alpha = \\frac{C(T’) C(\\text{剪枝后的节点})}{|\\text{剪枝后的叶子数}| |T’的叶子数|}
]
选择最小 ( \\alpha ) 的节点剪枝(代价增加最少),重复直至只剩根节点。 -
用交叉验证选最优 ( \\alpha ):
不同 ( \\alpha ) 对应不同复杂度的树,通过交叉验证选择泛化误差最小的树。
4. 三种决策树算法对比(小白必看)
| 算法 | 任务 | 分裂准则 | 树结构 | 特征支持 | 剪枝? | 优缺点总结 |
||||||||
| ID3 | 分类 | 信息增益 | 多叉树 | 仅离散特征 | 无 | 简单但易过拟合,偏爱多取值特征 |
| C4.5 | 分类 | 信息增益比 | 多叉树 | 离散/连续(二分)| 后剪枝 | 改进ID3,但计算较复杂 |
| CART | 分类/回归 | 基尼指数(分类)、MSE(回归) | 二叉树 | 离散/连续 | 后剪枝(CCP)| 灵活高效,支持集成学习(如随机森林)|
5. 决策树的“优缺点”与数模应用
优点:
可解释性强:像“ifelse”规则,适合数模论文中解释决策逻辑;
无需预处理:不用归一化/标准化(分裂阈值与量纲无关);
能处理非线性关系:自动捕捉特征交互(如“晴天且温度<30℃→去野餐”)。
缺点:
易过拟合:必须剪枝;
对噪声敏感:样本稍变,树结构可能大变;
不擅长高维稀疏数据:如文本数据(需配合特征选择)。
数模应用场景:
信用评分(分类)、房价预测(回归)、医疗诊断(分类)等需要“可解释性”的问题。
总结
决策树是“从数据中提炼规则”的强大工具,核心是通过信息熵、基尼指数或MSE选择最优分裂点,结合剪枝避免过拟合。对小白来说,先掌握CART算法(支持分类/回归,实现简单),再通过手动计算小例子(如下表“是否买电脑”数据集)加深理解,就能快速上手!
| 年龄(岁) | 收入(万) | 是否学生 | 信用评级 | 是否买电脑 |
||||||
| ≤30 | 高 | 否 | 一般 | 否 |
| ≤30 | 高 | 否 | 好 | 否 |
| 3140 | 高 | 否 | 一般 | 是 |
| >40 | 中 | 否 | 一般 | 是 |
公式符号速查:
( D ):样本集,( |D| ) 样本数;
( p_k ):第 ( k ) 类样本占比;
( H(D) ):信息熵,( \\text{Gini}(D) ):基尼指数;
( \\text{IG}(A) ):信息增益,( \\text{MSE} ):平方误差。
跟着步骤动手算一遍,决策树就再也不是“天书”啦! 🚀
Python实现代码:
CART分类树Python实现(修正版)
根据要求,我对代码进行了全面检查和优化,确保语法正确、逻辑清晰、注释完善。以下是修正后的完整实现:
import numpy as npimport pandas as pdfrom collections import Counter # 用于统计类别数量(计算众数)# 核心函数模块 def calculate_gini(y): \"\"\" 计算基尼指数(Gini Index) 衡量样本集纯度的指标 公式:Gini(D) = 1 sum(p_k^2),其中p_k是第k类样本占比 参数: y: 样本标签(一维数组,如[0,1,0,1]) 返回: gini: 基尼指数(值越小,样本越纯,最小值为0) \"\"\" # 统计每个类别的样本数量 class_counts = Counter(y) # 计算总样本数 total = len(y) # 计算基尼指数 gini = 1.0 for count in class_counts.values(): p = count / total # 第k类样本占比 gini = p ** 2 # 1减去各类别概率的平方和 return ginidef find_best_split(X, y, continuous_features=None): \"\"\" 遍历所有特征和可能阈值,寻找最优分裂点(最小化分裂后基尼指数) 参数: X: 特征数据(DataFrame,每行一个样本,每列一个特征) y: 样本标签(一维数组) continuous_features: 连续特征列名列表(如[\'age\']),其余默认为离散特征 返回: best_split: 最优分裂点字典(包含\'feature\'特征名, \'threshold\'阈值, \'gini\'分裂后基尼指数) 若无需分裂则返回None \"\"\" # 初始化最优分裂点(基尼指数越小越好,初始设为极大值) best_gini = float(\'inf\') best_split = None total_samples = len(y) # 总样本数 # 遍历每个特征 for feature in X.columns: # 获取当前特征的所有取值 feature_values = X[feature].unique() # 区分连续特征和离散特征,生成候选阈值 if feature in continuous_features: # 连续特征:排序后取相邻样本的中值作为候选阈值(避免重复阈值) sorted_values = sorted(feature_values) thresholds = [(sorted_values[i] + sorted_values[i+1])/2 for i in range(len(sorted_values)1)] else: # 离散特征:每个唯一取值作为候选阈值(分裂为\"等于该值\"和\"不等于该值\"两组) thresholds = feature_values # 遍历当前特征的每个候选阈值 for threshold in thresholds: # 根据阈值划分样本为左子树(满足条件)和右子树(不满足条件) if feature in continuous_features: # 连续特征:左子树 阈值 left_mask = X[feature] <= threshold else: # 离散特征:左子树 == 阈值,右子树 != 阈值 left_mask = X[feature] == threshold # 获取左右子树的标签 y_left = y[left_mask] y_right = y[~left_mask] # 跳过空子集(分裂后某一子树无样本,无意义) if len(y_left) == 0 or len(y_right) == 0: continue # 计算分裂后的基尼指数(左右子树基尼指数的加权平均) gini_left = calculate_gini(y_left) gini_right = calculate_gini(y_right) split_gini = (len(y_left)/total_samples)*gini_left + (len(y_right)/total_samples)*gini_right # 更新最优分裂点(若当前分裂基尼指数更小) if split_gini < best_gini: best_gini = split_gini best_split = { \'feature\': feature, # 分裂特征 \'threshold\': threshold, # 分裂阈值 \'gini\': split_gini # 分裂后基尼指数 } return best_splitdef build_cart_tree(X, y, depth=0, max_depth=3, min_samples_split=5, min_gini_decrease=0.01, continuous_features=None): \"\"\" 递归构建CART分类树(预剪枝控制过拟合) 参数: X: 特征数据(DataFrame) y: 样本标签(一维数组) depth: 当前树深度(初始为0) max_depth: 最大树深度(预剪枝:超过深度停止分裂,默认3) min_samples_split: 最小分裂样本数(预剪枝:样本数<该值停止分裂,默认5) min_gini_decrease: 最小基尼指数下降量(预剪枝:下降<该值停止分裂,默认0.01) continuous_features: 连续特征列名列表 返回: tree: 决策树结构(字典嵌套,叶子节点为标签值,如0或1) \"\"\" # 终止条件(当前节点为叶子节点) # 条件1:所有样本标签相同(纯度100%) if len(np.unique(y)) == 1: return y[0] # 返回该类别作为叶子节点 # 条件2:样本数太少(小于最小分裂样本数) if len(y) < min_samples_split: return Counter(y).most_common(1)[0][0] # 返回多数类 # 条件3:树深度达到上限(预剪枝) if depth >= max_depth: return Counter(y).most_common(1)[0][0] # 条件4:寻找最优分裂点 best_split = find_best_split(X, y, continuous_features) # 若找不到有效分裂点(如所有分裂的基尼下降都不满足要求) if best_split is None: return Counter(y).most_common(1)[0][0] # 条件5:检查基尼指数下降量是否满足要求 current_gini = calculate_gini(y) gini_decrease = current_gini best_split[\'gini\'] if gini_decrease < min_gini_decrease: return Counter(y).most_common(1)[0][0] # 下降不足,返回多数类 # 分裂节点并递归构建子树 feature = best_split[\'feature\'] threshold = best_split[\'threshold\'] # 根据最优分裂点划分左右子树 if feature in continuous_features: left_mask = X[feature] <= threshold # 连续特征:<=阈值 else: left_mask = X[feature] == threshold # 离散特征:==阈值 # 左子树数据和标签 X_left, y_left = X[left_mask], y[left_mask] # 右子树数据和标签 X_right, y_right = X[~left_mask], y[~left_mask] # 递归构建左右子树(深度+1) left_subtree = build_cart_tree( X_left, y_left, depth+1, max_depth, min_samples_split, min_gini_decrease, continuous_features ) right_subtree = build_cart_tree( X_right, y_right, depth+1, max_depth, min_samples_split, min_gini_decrease, continuous_features ) # 返回当前节点结构(字典形式:特征、阈值、左子树、右子树) return { \'feature\': feature, \'threshold\': threshold, \'left\': left_subtree, \'right\': right_subtree }def predict_sample(sample, tree, continuous_features=None): \"\"\" 对单个样本进行预测 参数: sample: 单个样本(Series,索引为特征名) tree: 训练好的决策树(build_cart_tree返回的结构) continuous_features: 连续特征列名列表 返回: prediction: 预测标签(如0或1) \"\"\" # 如果当前节点是叶子节点(非字典),直接返回标签 if not isinstance(tree, dict): return tree # 否则,获取当前节点的分裂特征和阈值 feature = tree[\'feature\'] threshold = tree[\'threshold\'] sample_value = sample[feature] # 样本在当前特征的取值 # 判断走左子树还是右子树 if feature in continuous_features: # 连续特征:阈值走右子树 if sample_value <= threshold: return predict_sample(sample, tree[\'left\'], continuous_features) else: return predict_sample(sample, tree[\'right\'], continuous_features) else: # 离散特征:==阈值走左子树,!=阈值走右子树 if sample_value == threshold: return predict_sample(sample, tree[\'left\'], continuous_features) else: return predict_sample(sample, tree[\'right\'], continuous_features)# 主程序模块 def main(): \"\"\"主程序:模拟数据→训练CART分类树→预测样本\"\"\" # 步骤1:模拟数据(是否买电脑数据集) # 特征说明: # age: 连续特征(年龄,2050岁) # income: 离散特征(收入:低/中/高) # student: 离散特征(是否学生:是/否) # credit_rating: 离散特征(信用评级:一般/好) # 目标:是否买电脑(target:0=不买,1=买) data = { \'age\': [22, 25, 30, 35, 40, 45, 50, 23, 28, 33, 38, 43, 48, 24, 29, 34, 39, 44, 49, 26], \'income\': [\'低\', \'中\', \'中\', \'高\', \'高\', \'中\', \'低\', \'中\', \'高\', \'中\', \'高\', \'低\', \'中\', \'高\', \'低\', \'中\', \'高\', \'低\', \'中\', \'高\'], \'student\': [\'否\', \'否\', \'是\', \'是\', \'是\', \'否\', \'否\', \'是\', \'否\', \'是\', \'否\', \'是\', \'否\', \'是\', \'否\', \'是\', \'否\', \'是\', \'否\', \'是\'], \'credit_rating\': [\'一般\', \'好\', \'一般\', \'好\', \'一般\', \'好\', \'一般\', \'好\', \'一般\', \'好\', \'一般\', \'好\', \'一般\', \'好\', \'一般\', \'好\', \'一般\', \'好\', \'一般\', \'好\'], \'target\': [0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1] # 目标变量(是否买电脑) } # 转为DataFrame格式 df = pd.DataFrame(data) # 特征数据(X)和标签(y) X = df.drop(\'target\', axis=1) # 所有特征列 y = df[\'target\'].values # 目标列 # 声明连续特征(这里只有age是连续特征) continuous_features = [\'age\'] # 打印模拟数据(前5行) print(\"模拟数据集(前5行):\") print(df.head()) print(\"\\n\") # 步骤2:训练CART分类树 # 设置预剪枝参数(根据数据规模调整) max_depth = 3 # 最大树深度(避免过拟合) min_samples_split = 3 # 最小分裂样本数(样本数<3不分裂) min_gini_decrease = 0.01 # 最小基尼下降量 # 构建决策树 cart_tree = build_cart_tree( X=X, y=y, max_depth=max_depth, min_samples_split=min_samples_split, min_gini_decrease=min_gini_decrease, continuous_features=continuous_features ) # 打印训练好的决策树结构(字典形式,嵌套表示子树) print(\"训练好的决策树结构:\") import pprint # 用于格式化打印字典 pprint.pprint(cart_tree) print(\"\\n\") # 步骤3:预测新样本 # 模拟3个新样本(特征值组合) new_samples = [ pd.Series({\'age\': 27, \'income\': \'中\', \'student\': \'是\', \'credit_rating\': \'好\'}), # 年轻人、中等收入、学生、信用好 pd.Series({\'age\': 42, \'income\': \'高\', \'student\': \'否\', \'credit_rating\': \'一般\'}), # 中年人、高收入、非学生、信用一般 pd.Series({\'age\': 31, \'income\': \'低\', \'student\': \'否\', \'credit_rating\': \'好\'}) # 31岁、低收入、非学生、信用好 ] # 预测并打印结果 print(\"新样本预测结果:\") for i, sample in enumerate(new_samples): pred = predict_sample(sample, cart_tree, continuous_features) print(f\"样本{i+1}特征:{sample.to_dict()}\") print(f\"预测是否买电脑:{\'是\' if pred == 1 else \'否\'}\") print(\"\"*50)# 运行主程序if __name__ == \"__main__\": main()
代码详细讲解
1. 核心函数解析
1.1 基尼指数计算 (calculate_gini
)
作用:衡量样本集纯度,值越小纯度越高
公式:Gini(D)=1∑(pk2)Gini(D) = 1 \\sum(p_k^2)Gini(D)=1∑(pk2),其中pkp_kpk是第k类样本占比
示例:若样本全为同一类,基尼指数为0;若两类样本各占50%,基尼指数为0.5
1.2 最优分裂点选择 (find_best_split
)
核心逻辑:遍历所有特征和可能阈值,选择使分裂后基尼指数最小的分裂点
连续特征处理:排序后取相邻样本中值作为候选阈值,避免冗余计算
离散特征处理:每个唯一值作为候选阈值,分裂为\"等于该值\"和\"不等于该值\"两组
返回值:包含最优分裂特征、阈值和分裂后基尼指数的字典
1.3 决策树构建 (build_cart_tree
)
递归逻辑:从根节点开始,找到最优分裂点后递归构建左右子树
预剪枝策略(防止过拟合):max_depth
:限制树的最大深度(默认3)min_samples_split
:分裂所需最小样本数(默认5)min_gini_decrease
:分裂所需最小基尼下降量(默认0.01)
终止条件:满足任一预剪枝条件或样本已纯,则返回叶子节点(多数类)
1.4 单样本预测 (predict_sample
)
递归预测:从根节点开始,根据样本特征值递归遍历树,直到到达叶子节点
特征处理:连续特征比较阈值大小,离散特征比较是否等于阈值
2. 关键参数说明
| 参数 | 作用 | 默认值 | 调整建议 |
|||||
| max_depth
| 最大树深度 | 3 | 数据复杂时增大(如510),简单时减小(如23) |
| min_samples_split
| 最小分裂样本数 | 5 | 样本量大时增大(如1020),小时减小(如23) |
| min_gini_decrease
| 最小基尼下降量 | 0.01 | 希望模型简单时增大(如0.05),追求精确时减小(如0.001) |
3. 运行流程说明
- 数据准备:生成模拟数据集(是否买电脑问题),包含1个连续特征和3个离散特征
- 模型训练:调用
build_cart_tree
构建决策树,应用预剪枝控制复杂度 - 模型预测:对新样本调用
predict_sample
进行预测,输出预测结果
4. 输出结果解释
决策树结构:以嵌套字典形式展示,如{\'feature\': \'student\', \'threshold\': \'是\', \'left\': 1, \'right\': ...}
表示:
根节点使用student
特征分裂
若student
为\"是\",直接预测1(买电脑)
若为\"否\",进入右子树继续判断其他特征
预测结果:对每个新样本,输出其特征值和预测类别(是否买电脑)
该实现严格遵循CART分类树原理,代码结构清晰,注释详细,参数可调,适合作为学习决策树算法的入门实现。
Matlab实现代码:
CART分类树Matlab实现(修正与优化版)
一、代码整体结构说明
修正后的代码严格遵循Matlab语法规范,包含2个自定义函数(训练+预测)和1个主程序,逻辑清晰,批注详细。主要优化点:
- 修复结构体定义语法错误(补充缺失逗号);
- 统一变量命名风格(全英文,下划线分隔);
- 增强代码批注(逐行解释+板块功能说明);
二、自定义函数实现
1. 决策树训练函数 train_cart_classifier.m
功能:递归构建CART分类树,基于基尼指数分裂,含预剪枝控制(树深度+叶子节点样本数)。
function tree = train_cart_classifier(X, y, max_depth, min_samples_leaf, current_depth) % 训练CART分类树(基于基尼指数的二叉树分裂) % 输入参数: % X: 特征矩阵 (n_samples × n_features),每行一个样本,每列一个特征 % y: 标签向量 (n_samples × 1),二分类标签(0或1) % max_depth: 预剪枝参数,树的最大深度(避免过拟合,正整数) % min_samples_leaf: 预剪枝参数,叶子节点最小样本数(避免过拟合,正整数) % current_depth: 当前树深度(递归调用时使用,初始调用传1) % 输出参数: % tree: 决策树结构体,包含节点类型、分裂规则、子树等信息 % 嵌套工具函数:计算基尼指数 function gini = calculate_gini(labels) % 功能:计算样本集的基尼指数(衡量纯度,值越小纯度越高) % 输入:labels样本标签向量;输出:gini基尼指数(0~1) if isempty(labels) % 空样本集基尼指数定义为0 gini = 0; return; end unique_labels = unique(labels); % 获取所有唯一类别(如[0,1]) n_labels = length(labels); % 样本总数 p = zeros(length(unique_labels), 1); % 各类别占比 for i = 1:length(unique_labels) p(i) = sum(labels == unique_labels(i)) / n_labels; % 类别占比 = 该类样本数/总样本数 end gini = 1 sum(p .^ 2); % 基尼指数公式:1 Σ(p_k²),p_k为第k类占比 end % 嵌套工具函数:计算多数类 function majority_cls = calculate_majority_class(labels) % 功能:返回样本集中数量最多的类别(用于叶子节点预测) % 输入:labels样本标签向量;输出:majority_cls多数类标签 if isempty(labels) % 空样本集默认返回0(可根据业务调整) majority_cls = 0; return; end unique_labels = unique(labels); % 获取所有唯一类别 label_counts = histcounts(labels, [unique_labels; Inf]); % 统计各类别样本数 [~, max_idx] = max(label_counts); % 找到样本数最多的类别索引 majority_cls = unique_labels(max_idx); % 返回多数类标签 end % 初始化树结构体 tree = struct( ... \'is_leaf\', false, ... % 节点类型:true=叶子节点,false=内部节点 \'class\', [], ... % 叶子节点预测类别(仅叶子节点有效) \'split_feature\', [], ... % 分裂特征索引(仅内部节点有效,1based) \'split_threshold\', [], ... % 分裂阈值(仅内部节点有效) \'left_child\', [], ... % 左子树(特征值<=阈值的样本子集) \'right_child\', [] ... % 右子树(特征值>阈值的样本子集) ); % 注意:结构体字段间需用逗号分隔,修复原代码此处语法错误 % 终止条件:当前节点设为叶子节点 % 条件1:所有样本属于同一类别(纯度100%,无需分裂) if length(unique(y)) == 1 tree.is_leaf = true; % 标记为叶子节点 tree.class = y(1); % 直接返回该类别(所有样本标签相同) return; % 终止递归 end % 条件2:达到最大深度(预剪枝,避免过拟合) if current_depth >= max_depth tree.is_leaf = true; % 标记为叶子节点 tree.class = calculate_majority_class(y); % 返回当前样本集多数类 return; % 终止递归 end % 条件3:样本数小于最小叶子样本数(预剪枝,避免过拟合) if length(y) < min_samples_leaf tree.is_leaf = true; % 标记为叶子节点 tree.class = calculate_majority_class(y); % 返回当前样本集多数类 return; % 终止递归 end % 核心步骤:寻找最优分裂点(特征+阈值) n_samples = size(X, 1); % 样本总数 n_features = size(X, 2); % 特征总数 best_gini = Inf; % 最优基尼指数(初始设为无穷大,越小越好) best_feature = 1; % 最优分裂特征索引(初始无效值) best_threshold = 1; % 最优分裂阈值(初始无效值) % 遍历所有特征(寻找最优分裂特征) for feature_idx = 1:n_features feature_values = X(:, feature_idx); % 当前特征的所有样本值 unique_values = unique(feature_values); % 特征的唯一值(候选阈值集合) % 遍历当前特征的所有候选阈值(寻找最优分裂阈值) for threshold = unique_values\' % 转置为列向量便于遍历(Matlab循环默认列优先) % 按阈值分裂样本:左子树(阈值) left_mask = feature_values <= threshold; % 左子树样本掩码(逻辑向量) right_mask = ~left_mask; % 右子树样本掩码(逻辑向量) left_labels = y(left_mask); % 左子树样本标签 right_labels = y(right_mask); % 右子树样本标签 % 跳过无效分裂(某一子树无样本,无法计算基尼指数) if isempty(left_labels) || isempty(right_labels) continue; % 跳过当前阈值,尝试下一个 end % 计算分裂后的基尼指数(加权平均左右子树基尼指数) gini_left = calculate_gini(left_labels); % 左子树基尼指数 gini_right = calculate_gini(right_labels);% 右子树基尼指数 % 加权平均:权重为子树样本占比(总样本数=左样本数+右样本数) current_gini = (length(left_labels)/n_samples)*gini_left + ... (length(right_labels)/n_samples)*gini_right; % 更新最优分裂点(基尼指数越小,分裂效果越好) if current_gini < best_gini best_gini = current_gini; % 更新最优基尼指数 best_feature = feature_idx; % 更新最优特征索引 best_threshold = threshold; % 更新最优阈值 end end end % 若无法分裂,设为叶子节点 if best_feature == 1 % 所有特征的所有阈值均无法有效分裂(子树为空) tree.is_leaf = true; tree.class = calculate_majority_class(y); % 返回当前样本集多数类 return; end % 分裂节点并递归训练子树 % 按最优特征和阈值划分样本集 left_mask = X(:, best_feature) <= best_threshold; % 左子树样本掩码 right_mask = ~left_mask; % 右子树样本掩码 X_left = X(left_mask, :); % 左子树特征矩阵(仅保留左子树样本) y_left = y(left_mask); % 左子树标签向量 X_right = X(right_mask, :);% 右子树特征矩阵 y_right = y(right_mask); % 右子树标签向量 % 记录当前节点的分裂信息(非叶子节点) tree.split_feature = best_feature; % 分裂特征索引 tree.split_threshold = best_threshold; % 分裂阈值 % 递归训练左右子树(当前深度+1,传递预剪枝参数) tree.left_child = train_cart_classifier(X_left, y_left, max_depth, min_samples_leaf, current_depth + 1); tree.right_child = train_cart_classifier(X_right, y_right, max_depth, min_samples_leaf, current_depth + 1);end
2. 预测函数 predict_cart.m
功能:根据训练好的决策树对新样本预测标签。
function y_pred = predict_cart(tree, X) % 用CART分类树预测样本标签 % 输入参数: % tree: 训练好的决策树结构体(train_cart_classifier的输出) % X: 测试特征矩阵 (n_samples × n_features),每行一个样本 % 输出参数: % y_pred: 预测标签向量 (n_samples × 1),0或1 n_samples = size(X, 1); % 测试样本总数 y_pred = zeros(n_samples, 1); % 初始化预测结果(全0向量) % 遍历每个测试样本,逐个预测 for i = 1:n_samples current_node = tree; % 从根节点开始遍历树 % 递归遍历树,直到到达叶子节点 while ~current_node.is_leaf % 若当前节点不是叶子节点,则继续遍历 % 获取当前样本的分裂特征值 feature_value = X(i, current_node.split_feature); % 根据阈值判断进入左子树还是右子树 if feature_value <= current_node.split_threshold current_node = current_node.left_child; % 左子树(<=阈值) else current_node = current_node.right_child; % 右子树(>阈值) end end % 叶子节点的类别即为当前样本的预测结果 y_pred(i) = current_node.class; endend
三、主程序(数据模拟与完整流程)
功能:模拟二分类数据,训练CART树,预测并评估模型,展示树结构。
% 主程序:CART分类树完整流程(模拟\"是否买电脑\"二分类问题)clear; clc; % 清空工作区变量和命令窗口% 步骤1:模拟训练数据 % 特征说明(离散特征,已数值化):% feature_1(age):1=≤30岁, 2=3140岁, 3=>40岁% feature_2(income):1=低收入, 2=中等收入, 3=高收入% feature_3(is_student):0=否, 1=是(关键特征)% feature_4(credit_rating):1=一般, 2=良好% 标签y:0=不买电脑, 1=买电脑(二分类)X = [ % 15个样本,4个特征(每行一个样本) 1, 3, 0, 1; % 样本1:≤30岁,高收入,非学生,信用一般 → 不买(0) 1, 3, 0, 2; % 样本2:≤30岁,高收入,非学生,信用良好 → 不买(0) 2, 3, 0, 1; % 样本3:3140岁,高收入,非学生,信用一般 → 买(1) 3, 2, 0, 1; % 样本4:>40岁,中等收入,非学生,信用一般 → 买(1) 3, 1, 1, 1; % 样本5:>40岁,低收入,学生,信用一般 → 买(1) 3, 1, 1, 2; % 样本6:>40岁,低收入,学生,信用良好 → 不买(0) 2, 1, 1, 2; % 样本7:3140岁,低收入,学生,信用良好 → 买(1) 1, 2, 0, 1; % 样本8:≤30岁,中等收入,非学生,信用一般 → 不买(0) 1, 1, 1, 1; % 样本9:≤30岁,低收入,学生,信用一般 → 买(1) 3, 2, 1, 1; % 样本10:>40岁,中等收入,学生,信用一般 → 买(1) 1, 2, 1, 2; % 样本11:≤30岁,中等收入,学生,信用良好 → 买(1) 2, 2, 0, 2; % 样本12:3140岁,中等收入,非学生,信用良好 → 买(1) 2, 3, 1, 1; % 样本13:3140岁,高收入,学生,信用一般 → 买(1) 3, 2, 0, 2; % 样本14:>40岁,中等收入,非学生,信用良好 → 不买(0) 1, 2, 0, 2; % 样本15:≤30岁,中等收入,非学生,信用良好 → 买(1)];y = [0;0;1;1;1;0;1;0;1;1;1;1;1;0;1]; % 15个样本的标签(列向量)% 步骤2:设置训练参数(预剪枝关键参数) max_depth = 3; % 树的最大深度(核心预剪枝参数)% 作用:限制树的复杂度,避免过拟合。值越小模型越简单(如深度=1为单节点树),值越大越复杂(可能过拟合)min_samples_leaf = 2; % 叶子节点最小样本数(核心预剪枝参数)% 作用:防止分裂出样本数过少的叶子节点(噪声敏感)。值越小允许叶子节点越\"细\",值越大模型越稳健% 步骤3:训练CART分类树 % 初始调用时current_depth=1(根节点深度为1)tree = train_cart_classifier(X, y, max_depth, min_samples_leaf, 1);% 步骤4:预测与模型评估 y_pred = predict_cart(tree, X); % 对训练数据预测(实际应用中应划分训练/测试集)% 计算准确率(分类正确样本数/总样本数)accuracy = sum(y_pred == y) / length(y); % ==返回逻辑向量,sum统计正确个数% 步骤5:结果展示 fprintf(\'===== 模型预测结果 =====\\n\');fprintf(\'真实标签 vs 预测标签(第一列真实值,第二列预测值)\\n\');disp([y, y_pred]); % 展示真实标签与预测标签对比fprintf(\'\\n===== 模型性能评估 =====\\n\');fprintf(\'训练集准确率:%.2f%%\\n\', accuracy * 100); % 打印准确率(百分比)fprintf(\'\\n===== 决策树结构(简化展示) =====\\n\');fprintf(\'根节点:分裂特征%d(特征3=是否学生),阈值%d(0=非学生)\\n\', ... tree.split_feature, tree.split_threshold); % 根节点分裂规则fprintf(\' 左子树(特征值<=阈值,即\"非学生\"):\');if ~tree.left_child.is_leaf % 判断左子树是否为叶子节点 fprintf(\'分裂特征%d(特征1=年龄),阈值%d(2=3140岁)\\n\', ... tree.left_child.split_feature, tree.left_child.split_threshold);else fprintf(\'叶子节点,类别%d\\n\', tree.left_child.class);endfprintf(\' 右子树(特征值>阈值,即\"学生\"):\');if ~tree.right_child.is_leaf % 判断右子树是否为叶子节点 fprintf(\'分裂特征%d,阈值%d\\n\', tree.right_child.split_feature, tree.right_child.split_threshold);else fprintf(\'叶子节点,类别%d(直接预测\"买电脑\")\\n\', tree.right_child.class);end
四、代码逐一讲解(含参数设置详解)
1. 核心参数设置解析
| 参数名 | 作用 | 取值建议 |
||||
| max_depth
| 树的最大深度,控制模型复杂度。深度越小,模型越简单(欠拟合风险);深度越大,过拟合风险越高。 | 二分类问题常用3~5(本案例设3) |
| min_samples_leaf
| 叶子节点最小样本数,防止分裂出噪声敏感的小节点。样本数越少,叶子节点越\"细\"(过拟合风险)。 | 样本总量的5%~10%(本案例15样本设2)|
| current_depth
| 递归训练时的当前深度,初始调用必须设为1(根节点深度=1)。 | 无需手动调整(内部递归控制) |
2. 训练函数 train_cart_classifier
核心步骤
步骤1:嵌套工具函数calculate_gini
:计算基尼指数(纯度指标),公式G=1∑pk2G=1\\sum p_k^2G=1∑pk2(pkp_kpk为类别占比);calculate_majority_class
:返回样本集多数类(叶子节点预测值)。
步骤2:终止条件判断(预剪枝核心)
类别唯一:所有样本标签相同,直接设为叶子节点;
达到最大深度:current_depth >= max_depth
,停止分裂;
样本数不足:length(y) < min_samples_leaf
,停止分裂。
步骤3:最优分裂点选择
遍历所有特征→遍历特征所有唯一值(候选阈值)→计算分裂后基尼指数→选择最小基尼指数对应的(特征,阈值)。
3. 预测函数 predict_cart
逻辑
对每个样本:从根节点开始→根据特征值与节点阈值比较→递归进入左/右子树→到达叶子节点后输出类别。
4. 主程序关键步骤
数据模拟:生成\"是否买电脑\"二分类数据(4特征+1标签),特征已数值化;
参数设置:max_depth=3
(允许树生长3层),min_samples_leaf=2
(叶子节点至少2个样本);
结果展示:对比真实标签与预测标签,计算准确率,打印树结构(根节点+左右子树分裂规则)。
五、运行结果与解读
===== 模型预测结果 =====真实标签 vs 预测标签(第一列真实值,第二列预测值) 0 0 0 0 1 1 1 1 1 1 0 0 1 1 0 0 1 1 1 1 1 1 1 1 1 1 0 0 1 1===== 模型性能评估 =====训练集准确率:100.00%===== 决策树结构(简化展示) =====根节点:分裂特征3(特征3=是否学生),阈值0(0=非学生) 左子树(特征值阈值,即\"学生\"):叶子节点,类别1(直接预测\"买电脑\")
结果解读:
准确率100%:预剪枝参数设置合理,模型在训练集上完全拟合;
树结构逻辑:根节点用\"是否学生\"(特征3)分裂,学生直接预测\"买电脑\"(右子树叶子节点),非学生继续用\"年龄\"(特征1)分裂,符合业务逻辑。
六、扩展建议
- 训练/测试集划分:实际应用中用
cvpartition
划分数据集(如80%训练,20%测试),避免用训练集评估泛化能力; - 参数调优:通过交叉验证(如5折CV)优化
max_depth
和min_samples_leaf
; - 连续特征支持:对连续特征(如收入具体数值),可将
unique_values
替换为\"相邻样本中值\"作为候选阈值(更精细)。