专栏名称: 新语数据故事汇
《新语数据故事汇,数说新语》科普数据科学、讲述数据故事,深层次挖掘数据价值。
目录
相关文章推荐
51好读  ›  专栏  ›  新语数据故事汇

数据科学必备:掌握训练集、验证集与测试集的划分

新语数据故事汇  · 公众号  ·  · 2024-07-20 21:55

正文

在数据科学与机器学习领域,一个模型能否在先前未观测到的新数据上表现良好,这种能力被称为泛化(generalization)。模型的泛化能力是衡量其有效性的重要指标,也是机器学习面临的主要挑战之一。为了实现良好的泛化能力,我们的算法必须在新数据上表现出色,而不仅仅是在训练数据上取得高精度。

模型过度学习训练数据(即过拟合)是一个常见的问题。过拟合的模型在训练数据上表现优异,但在实际部署后遇到新数据时却表现不佳。为了避免这一问题,我们需要一种机制来评估模型的泛化能力。这就是为什么在模型训练过程中,需要将数据划分为训练集、验证集和测试集。

训练集用于训练模型,使其学习数据的模式;验证集用于调整模型的参数和选择最佳模型;测试集则用于评估模型的最终性能。通过这种数据划分,我们可以有效地防止模型过拟合,并确保其在实际应用中的表现。

我们下来理解误差及泛化误差,由误差关系导致的过拟合和欠拟合以及模型的能力:

  • 误差:误差包括训练误差(training error)和泛化误差(generalization error)(也称为测试误差,test error);机器学习的核心目标是降低泛化误差。

  • 过拟合:过拟合是指训练误差和和测试误差之间的差距太大。

  • 欠拟合:欠拟合是指模型不能在训练集上获得足够低的误差。

  • 容量:模型的容量是指其拟合各种函数的能力。

容量和误差之间的典型关系如下图。训练误差和测试误差表现得非常不同。在图的左端, 训练误差和泛化误差都非常高。这是 欠拟合机制(underfitting regime)。当我们增加容量时, 训练误差减小,但是训练误差和泛化误差之间的间距却不断扩大。最终,这个间距的大小超过了训练误差的下降,我们进入到了 过拟合机制(overfitting regime),其中容量过大,超过了 最佳容量(optimal capacity)

接下来,我们将详细介绍如何科学地划分数据集,以及不同数据集在模型训练中的具体作用和意义。掌握这些数据划分技巧,将帮助你构建出更加稳健和可靠的机器学习模型。

训练-验证-测试(数据集)划分的定义

训练-验证-测试(数据集)划分是一种评估机器学习模型(无论是分类还是回归)性能的技术。你将一个给定的数据集分成三个子集。以下是对每个数据集角色的简要描述。

  • 训练数据集( Train Dataset) :用于学习(由模型进行)的数据集,即拟合机器学习模型参数的数据。

  • 验证数据集( Valid Dataset ):用于在调整模型超参数时,对训练数据集上拟合的模型进行无偏评估的数据集。也在其他形式的模型准备中发挥作用,如特征选择、阈值选择等。

  • 测试数据集( Test Dataset ):用于对训练数据集上拟合的最终模型进行无偏评估的数据集。

接下来介绍两种将数据划分为训练集、验证集和测试集的方法:

  • 随机划分

  • 使用时间序列组件划分

随机划分(Splitting Randomly)

你不能使用与训练相同的数据来评估模型的预测性能。最好使用模型之前未见过的新数据来评估模型。随机划分数据是最常用的方法,用于进行这种无偏评估。

使用 Sklearn(train_test_split)

下面的代码是如何使用 使用 Sklearn( train_test_split ) 两次来创建我们所需比例的训练集、验证集和测试集。

import pandas as pdfrom sklearn.model_selection import train_test_splitdf = pd.read_csv('Iris.csv')
train_size=0.8
X = df.drop(columns = ['Species']).copy()y = df['Species']
X_train, X_rem, y_train, y_rem = train_test_split(X,y, train_size=0.8)
test_size = 0.5X_valid, X_test, y_valid, y_test = train_test_split(X_rem,y_rem, test_size=0.5)
print(X_train.shape,y_train.shape)print(X_valid.shape,y_valid.shape)print(X_test.shape,y_test.shape)

使用 Fast_ml(train_valid_test_split)







请到「今天看啥」查看全文