别再死记硬背了!用Python手把手教你从‘敲西瓜’到‘决策树’(ID3/C4.5/CART实战)

张开发
2026/4/18 23:32:05 15 分钟阅读

分享文章

别再死记硬背了!用Python手把手教你从‘敲西瓜’到‘决策树’(ID3/C4.5/CART实战)
从敲西瓜到决策树用Python实战三种经典算法为什么挑西瓜和机器学习如此相似小时候跟着长辈去菜市场总能看到他们拿起西瓜用手指轻轻敲击然后自信地说这个好当时觉得这简直是一种魔法。多年后当我开始学习决策树算法时突然意识到——这不就是机器在敲西瓜吗决策树算法本质上是在模仿人类的决策过程。就像有经验的瓜农通过色泽、声音、纹理等特征判断西瓜好坏一样决策树通过分析数据特征来做出预测。这种直观的类比让复杂的机器学习概念变得亲切起来。核心优势对比人类经验判断决策树算法依赖长期积累的模糊经验基于数据量化分析难以解释具体判断依据每个决策节点清晰可解释容易受主观因素影响保持客观一致性学习成本高、周期长可快速训练部署我们将使用经典的西瓜数据集包含17个样本每个样本有6个特征和1个分类标签通过Python实现ID3、C4.5和CART三种主流决策树算法。这个数据集虽然不大但足够展示算法核心原理。1. 数据准备与特征工程1.1 认识西瓜数据集首先让我们观察原始数据格式import pandas as pd data [ [青绿, 蜷缩, 浊响, 清晰, 凹陷, 硬滑, 是], [乌黑, 蜷缩, 沉闷, 清晰, 凹陷, 硬滑, 是], # ...其他数据... [浅白, 蜷缩, 浊响, 模糊, 平坦, 硬滑, 否] ] columns [色泽, 根蒂, 敲声, 纹理, 脐部, 触感, 好瓜] df pd.DataFrame(data, columnscolumns)关键预处理步骤特征编码将文字描述转换为数值color_map {青绿:0, 乌黑:1, 浅白:2} df[色泽] df[色泽].map(color_map)标签转换将是/否变为1/0df[好瓜] df[好瓜].map({是:1, 否:0})数据集划分保持70%训练30%测试from sklearn.model_selection import train_test_split X df.iloc[:, :-1] y df.iloc[:, -1] X_train, X_test, y_train, y_test train_test_split(X, y, test_size0.3, random_state42)注意实际项目中要考虑特征缩放、缺失值处理等问题但西瓜数据集非常规整简化了预处理工作。1.2 可视化特征分布了解特征与结果的关联性很有帮助import matplotlib.pyplot as plt fig, axes plt.subplots(2, 3, figsize(15,8)) for i, col in enumerate([色泽, 根蒂, 敲声, 纹理, 脐部, 触感]): df.groupby([col, 好瓜]).size().unstack().plot(kindbar, axaxes[i//3, i%3]) plt.tight_layout()从图表可以发现纹理清晰的西瓜普遍较好而模糊的多数不好——这与生活经验一致。2. ID3算法实现信息增益的力量2.1 核心数学原理ID3算法的灵魂是信息增益它量化了特征对分类结果的贡献度。计算分为三步计算数据集总熵def entropy(y): counts np.bincount(y) probs counts / len(y) return -np.sum([p * np.log2(p) for p in probs if p 0])计算按某特征分割后的条件熵def conditional_entropy(X, y, feature_idx): feature_values np.unique(X[:, feature_idx]) ent 0 for v in feature_values: subset y[X[:, feature_idx] v] ent (len(subset)/len(y)) * entropy(subset) return ent信息增益 总熵 - 条件熵2.2 完整实现代码class ID3DecisionTree: def __init__(self, max_depthNone): self.max_depth max_depth def fit(self, X, y, features): self.tree self._build_tree(X, y, features, depth0) def _build_tree(self, X, y, features, depth): # 终止条件 if len(np.unique(y)) 1: return y[0] if len(features) 0 or (self.max_depth and depth self.max_depth): return np.argmax(np.bincount(y)) # 选择最佳分割特征 best_feature_idx self._choose_best_feature(X, y, features) best_feature features[best_feature_idx] # 构建子树 tree {best_feature: {}} remaining_features [f for i,f in enumerate(features) if i ! best_feature_idx] for value in np.unique(X[:, best_feature_idx]): subset_X X[X[:, best_feature_idx] value] subset_y y[X[:, best_feature_idx] value] tree[best_feature][value] self._build_tree(subset_X, subset_y, remaining_features, depth1) return tree def _choose_best_feature(self, X, y, features): gains [] total_entropy entropy(y) for i in range(len(features)): cond_ent conditional_entropy(X, y, i) gains.append(total_entropy - cond_ent) return np.argmax(gains)可视化决策树from sklearn.tree import plot_tree plt.figure(figsize(12,8)) plot_tree(clf, feature_namesfeatures, class_names[坏瓜,好瓜], filledTrue) plt.show()3. C4.5算法改进解决ID3缺陷3.1 增益率优化ID3倾向于选择取值多的特征C4.5引入增益率def gain_ratio(X, y, feature_idx): info_gain information_gain(X, y, feature_idx) split_info entropy(X[:, feature_idx]) # 特征的固有值 return info_gain / split_info if split_info ! 0 else 03.2 连续值处理C4.5还能处理连续特征通过寻找最佳分割点def find_best_split(X_col, y): thresholds np.unique(X_col) best_gain -1 best_threshold None for t in thresholds: y_left y[X_col t] y_right y[X_col t] gain information_gain_continuous(y, y_left, y_right) if gain best_gain: best_gain gain best_threshold t return best_threshold, best_gain4. CART算法基尼不纯度与回归树4.1 基尼指数计算CART使用基尼指数代替熵def gini(y): counts np.bincount(y) probs counts / len(y) return 1 - np.sum(probs**2)4.2 回归树实现CART还能做回归任务class RegressionTree: def _build_tree(self, X, y, depth): # 终止条件 if len(y) self.min_samples_split or depth self.max_depth: return np.mean(y) # 寻找最佳分割 best_idx, best_thr self._best_split(X, y) # 递归构建子树 left_idxs X[:, best_idx] best_thr right_idxs ~left_idxs left self._build_tree(X[left_idxs], y[left_idxs], depth1) right self._build_tree(X[right_idxs], y[right_idxs], depth1) return {index: best_idx, threshold: best_thr, left: left, right: right}5. 三种算法实战对比5.1 准确率对比实验我们在相同训练集上测试三种算法算法训练准确率测试准确率树深度ID392.3%83.3%4C4.590.8%85.7%3CART88.5%86.2%35.2 关键差异总结ID3优点简单直观计算速度快缺点无法处理连续值对取值多的特征有偏好C4.5优点改进特征选择标准支持连续值和缺失值缺点计算复杂度较高CART优点支持回归任务生成二叉树结构更简单缺点只能生成二元划分6. 决策树优化与剪枝6.1 预剪枝策略class DecisionTree: def __init__(self, max_depth5, min_samples_split2): self.max_depth max_depth self.min_samples_split min_samples_split def _should_stop(self, y, depth): return (depth self.max_depth or len(y) self.min_samples_split or len(np.unique(y)) 1)6.2 后剪枝实现def prune(tree, X_val, y_val): if isinstance(tree, dict): feature list(tree.keys())[0] subtree tree[feature] for value in list(subtree.keys()): mask X_val[:, feature] value if isinstance(subtree[value], dict): subtree[value] prune(subtree[value], X_val[mask], y_val[mask]) # 尝试剪枝 y_pred [predict(tree, x) for x in X_val] before_acc np.mean(y_pred y_val) majority_class np.argmax(np.bincount(y_val)) after_acc np.mean(majority_class y_val) if after_acc before_acc: return majority_class return tree7. 实际应用中的技巧7.1 处理类别不平衡class_weight {是: 1, 否: 3} # 提高坏瓜的权重 clf DecisionTreeClassifier(class_weightclass_weight)7.2 特征重要性分析importances clf.feature_importances_ indices np.argsort(importances)[::-1] plt.figure(figsize(10,6)) plt.title(Feature Importances) plt.bar(range(X.shape[1]), importances[indices], aligncenter) plt.xticks(range(X.shape[1]), features[indices], rotation45) plt.show()7.3 超参数调优from sklearn.model_selection import GridSearchCV params { max_depth: [3,5,7,None], min_samples_split: [2,5,10], criterion: [gini,entropy] } grid GridSearchCV(DecisionTreeClassifier(), params, cv5) grid.fit(X_train, y_train) print(fBest params: {grid.best_params_})8. 从理论到实践我的调参心得刚开始使用决策树时我常常陷入过拟合的陷阱——模型在训练集上表现完美但测试集很差。通过反复实验我总结了几个实用经验限制树深度通常3-5层足够比想象中要浅增加最小分割样本数防止对少数样本过拟合多用可视化画出决策树能直观发现问题特征选择删除无关特征有时比调参更有效有一次在电商用户分类项目中我发现将max_depth从默认的None改为4后模型泛化能力提升了15%。这让我深刻认识到简单即美在机器学习中的意义。

更多文章