Early Stopping 是 PyTorch 中用于模型训练时防止过拟合的一种技术。其基本思想是在训练过程中监控验证集上的性能,当验证集上的性能不再提升或者开始下降时,停止训练。这样可以避免在训练集上过度拟合,从而提高模型的泛化能力。
在 PyTorch 中实现 Early Stopping 非常简单,只需要在训练过程中记录验证集上的损失值和准确率,然后比较损失值和之前最好的一次的损失值的大小,如果验证集上的损失值更大,则停止训练。
Early Stopping 的优点在于可以在不使用交叉验证的情况下实现模型训练的早期停止,提高了训练效率。同时,由于在训练过程中会不断保存验证集上的性能最好的模型,因此可以通过再次训练来进一步提高模型的性能。
在使用 Early Stopping 时,需要注意记录验证集上的损失值和准确率,以及选择合适的损失函数和评估指标,以便正确地评估模型的性能。此外,还需要注意在训练过程中定期检查验证集上的性能,以免错过停止训练的最佳时机。
下面是一个简单的使用 Early Stopping 的 PyTorch 代码示例:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# 定义模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 3)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 3)
self.fc1 = nn.Linear(16 * 4 * 4, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 4 * 4)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
# 超参数设置
batch_size = 100
learning_rate = 0.001
num_epochs = 20
# 数据准备
train_dataset = ...
val_dataset = ...
train_loader = DataLoader(TensorDataset(train_dataset), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(val_dataset), batch_size=batch_size, shuffle=False)
# 模型、优化器和损失函数初始化
model = Net()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()
# 训练过程
best_val_loss = float('inf')
for epoch in range(num_epochs):
model.train()
train_loss = 0.0
correct = 0
total = 0
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
train_loss += loss.item()
_, predicted = torch.max(output, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
train_loss /= total
model.eval()
val_loss = 0.0
val_correct = 0
with torch.no_grad():
for data, target in val_loader:
output = model(data)
loss = criterion(output, target)
val_loss += loss.item()
_, predicted = torch.max(output, 1)
val_correct += (predicted == target).sum().item()
val_loss /= len(val_loader.dataset)
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f},
点击查看更多内容
为 TA 点赞
评论
共同学习,写下你的评论
评论加载中...
作者其他优质文章
正在加载中
感谢您的支持,我会继续努力的~
扫码打赏,你说多少就多少
赞赏金额会直接到老师账户
支付方式
打开微信扫一扫,即可进行扫码打赏哦