交叉验证
交叉验证 (cross-validation) 是机器学习过程中避免过拟合 (overfitting) 以及优化模型参数的一个非常有效的措施。借助 scikit-learning Python机器学习库,可以非常方便地对样本进行处理、分组,以实现交叉验证。
避免过拟合的一个最简单的方法就是将样本分为训练集 (training set) 和测试集 ( test set),分别用于模型的训练和评估。Python 实现如下:
1 | import numpy as np |
然而,尽管将样本分为了两个集合,测试集的结果依然有可能会影响训练过长,产生过拟合现象。因此,为确保准确性,实际操作中通常将样本分为三个部分,即训练集、验证集 (validation set) 和测试集。在训练过程中,用验证机评估训练的效果,测试集只在最终评估时使用。
不过这种方法的缺点在于样本的利用率较低,大量样本没有用于训练。
K-折交叉验证
K-折交叉验证 (K-Fold Cross-Validation) 是一种样本利用率更高的交叉验证方法。这种方法首先将除了测试集以外的样本分为\(k\)份,用其中一份作为验证集,其余的\(k-1\)份作为训练集进行训练。之后选取新的一份作为验证集,其他的作为训练集进行训练。如此反复,直到每份都参与过验证和训练为止。如下图所示:
![K-Fold 过程示意图](https://scikit-learn.org/stable/_images/grid_search_cross_validation.png)
计算交叉验证指标
使用 cross_val_score 方法
在 scikit-learn 中,交叉验证可以使用函数cross_val_score
来完成。当cv
参数是整数时,使用 KFold 或 StratifiedKFold 策略进行交叉验证。
1 | from sklearn.model_selection import cross_val_score |
可以获得模型评分的95%:
1 | print("Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() * 2)) |
可以设置方法中的scoring
参数来更改评分模式:
1 | from sklearn import metrics |
也可以使用特定方法进行交叉验证,此时只需将方法名传入cv
参数中即可。
注意,对数据进行预处理也可以提升模型的准确度:
1
2
3
4
5
6
7
8 from sklearn import preprocessing
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target,
test_size=0.4, random_state=0)
scaler = preprocessing.StandardScaler().fit(X_train)
X_train_transformed = scaler.transform(X_train)
clf = svm.SVC(C=1).fit(X_train_trainsformed, y_train)
X_test_transformed = scaler.transform(X_test)
print(clf.score(X_test_transformed, y_test)) # 0.9333...可以使用 Pipeline 封装整个优化过程:
1
2
3
4
5 from sklearn.pipeline import make_pipeline
clf = make_pipeline(preprocessing.StandardScaler(), svm.SVC(C=1))
print(cross_val_score(clf, iris.data, iris.target, cv=cv))
# ...
# array([0.977..., 0.933..., 0.955..., 0.933..., 0.977...])可以使用
cross_val_predict
函数获得模型的预测结果。
使用 cross_validate 方法
cross_validate
方法可以使用多种验证指标,并且能够返回更多的信息,包括训练得分、你和次数和得分次数:
1 | from sklearn.model_selection import cross_validate |
交叉验证迭代器
可以通过创建交叉验证迭代器来对样本进行分组。可用的交叉验证迭代器有:
KFold
RepeatedKFold
LeaveOneOut
LeavePOut
ShuffleSplit
StratifiedKFold
StratifiedShuffleSplit
GroupKFold
LeaveOneGroupOut
LeavePGroupsOut
GroupShuffleSplit
TimeSeriesSplit
- ...