Pytorch 4 :Save & Load & Pretrain
导言
- 保存与加载模型:学习如何保存训练好的模型,并在需要时加载模型进行推理或继续训练。
- 迁移学习:学习如何使用预训练模型进行迁移学习,微调模型以适应新的任务。
- 常用预训练模型:介绍PyTorch中常用的预训练模型,如ResNet、VGG等。
保存与读取¶
在 PyTorch 训练模型时,我们通常需要保存模型,以便后续继续训练或进行推理(Inference)。PyTorch 提供了两种常见的保存方式:
- 仅保存模型参数(推荐方式)
- 保存完整模型(包括结构和参数)
下面是详细的方法和代码示例。
1. 仅保存模型参数(推荐)¶
这种方式只保存模型的 state_dict
(即模型的参数),但不包含模型结构。优点是更灵活,加载时可以创建相同结构的模型,再加载参数。
1.1 保存模型参数¶
这样会把model.pth
作为文件保存,其中仅包含模型的参数(权重和偏置)。
1.2 加载模型参数¶
加载时需要先定义模型的结构,然后再加载 state_dict
:
import torch
import torch.nn as nn
# 重新定义模型结构(必须和保存时的模型结构一致)
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 20)
def forward(self, x):
return self.fc(x)
# 创建模型实例
model = MyModel()
# 加载保存的参数
model.load_state_dict(torch.load("model.pth"))
# 切换为评估模式(如果是推理)
model.eval()
注意: - 需要 手动创建模型结构,否则无法正确加载参数。 -
model.eval()
让模型进入推理模式(影响BatchNorm
和Dropout
)。推理时必须调用 model.eval(),否则 Dropout 仍然生效,BN 计算方式也不同,可能导致预测结果不稳定!
2. 保存完整模型(包括结构+参数)¶
如果希望保存整个模型(包括结构和参数),可以直接保存 model
:
full_model.pth
会包含 模型结构 + 训练参数。
2.1 加载完整模型¶
注意: - 这种方法适用于简单的模型,但依赖 Python 代码,如果代码环境变化(如不同版本 PyTorch),可能无法加载。 - 推荐保存
state_dict
,而不是整个模型,因为state_dict
更通用、兼容性更好。
3. 训练过程中定期保存模型¶
在训练时,我们通常希望 定期保存模型,例如每 10 个 epoch 保存一次:
for epoch in range(num_epochs):
train_one_epoch(model, optimizer) # 训练代码
if epoch % 10 == 0: # 每 10 轮保存一次
torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")
这样可以在训练中断后继续训练,或者选择最优的模型进行推理。
4. 保存和加载模型 + Optimizer(继续训练)¶
如果要继续训练,需要同时保存 模型参数 + 优化器状态:
4.1 保存模型和优化器¶
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss_value
}, "checkpoint.pth")
4.2 加载并继续训练¶
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] # 继续训练的起点
loss_value = checkpoint['loss']
model.train() # 切换回训练模式
总结¶
方法 | 代码 | 适用场景 | 优缺点 |
---|---|---|---|
仅保存参数(推荐) | torch.save(model.state_dict(), "model.pth") |
适用于大多数情况(推理、继续训练) | 需要手动定义模型结构,但兼容性好 |
保存完整模型 | torch.save(model, "full_model.pth") |
适用于小型项目或临时存储 | 依赖 Python 代码,可能导致兼容性问题 |
保存模型+优化器 | torch.save({...}, "checkpoint.pth") |
适用于断点续训 | 可以继续训练,但文件较大 |
如果你的目标是 部署推理,推荐 只保存 state_dict
,这样加载更灵活!😃
判断保存格式¶
在 PyTorch 中,.pth
文件通常是通过 torch.save()
保存的模型文件。它可能包含以下内容:
- 仅模型参数:通过
torch.save(model.state_dict(), 'model.pth')
保存。 - 模型结构和参数:通过
torch.save(model, 'model.pth')
保存。 - 模型、优化器和其他信息:通过
torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), ...}, 'model.pth')
保存。
以下是判断和打印结构的方法:
1. 检查文件内容¶
首先加载文件,查看其内容:
- 如果输出是
dict_keys(['state_dict'])
或类似内容layer1.weight
和layer1.bias
,说明只保存了模型参数。 - 如果输出包含
dict_keys(['model', 'optimizer', 'epoch', ...])
,说明保存了模型、优化器等信息。 - 如果输出是模型结构(如
OrderedDict
或torch.nn.Module
),说明保存了完整的模型结构和参数。
2. 打印模型结构¶
如果保存的是完整模型(结构和参数)¶
如果保存的是模型参数(state_dict
)¶
你需要先定义模型结构,然后加载参数:
import torch
import torch.nn as nn
# 假设你的模型类为 MyModel
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 实例化模型
model = MyModel()
# 加载参数
model.load_state_dict(torch.load('model.pth'))
# 打印模型结构
print(model)
网络结构发生变化(例如 input_channel 的值变了),加载参数时会失败。
RuntimeError: Error(s) in loading state_dict:
size mismatch for layer1.weight: copying a param with shape torch.Size([64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 256, 3, 3]).
如果只有部分层的参数形状不匹配,可以手动加载匹配的参数,忽略不匹配的部分:
可能不止一个state_dict权重, 用户可能会保存ema_state_dict
之类的其他权重,选择需要的。
ema_state_dict
表示的是 Exponential Moving Average (EMA) 模型参数的字典。EMA 是一种用于模型权重平滑的技术,它通过计算模型参数的移动平均值来提高模型的泛化能力和稳定性。
!!! tip register_buffer
"
`register_buffer` 是 PyTorch 中 `nn.Module` 类的一个方法,用于注册缓冲区(buffer)。缓冲区是与模型参数类似但又不同的变量,它们的主要特点如下:
- **持久化存储**:缓冲区会被保存在模型的状态字典中,因此在**保存和加载模型**时也会一并保存和加载。
- **不参与梯度计算**:缓冲区不会被自动求导机制跟踪,即不会计算梯度,也不会更新通过优化器。
- **设备迁移**:缓冲区会随着模型一起迁移到不同的设备(如 CPU 或 GPU)。
```python
self.register_buffer('h',
torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
```
这段代码的作用是:
- 创建一个名为 `h` 的缓冲区,并将其值初始化为一个特定的张量。这个张量表示小波变换中的低通滤波器卷积核。
- 这个缓冲区会在模型保存和加载时被保留,并且不会参与反向传播的梯度计算。
总结来说,`register_buffer` 适用于那些需要持久化存储、但不需要进行梯度更新的变量,例如统计量、预定义的权重矩阵等。
3. 判断是否包含优化器¶
如果文件是字典形式,检查是否包含优化器的 state_dict
:
if 'optimizer' in checkpoint:
print("文件包含优化器状态")
optimizer_state = checkpoint['optimizer']
else:
print("文件不包含优化器状态")
4. 总结¶
- 如果文件是
state_dict
,你需要先定义模型结构,再加载参数。 - 如果文件是完整模型,可以直接加载并使用。
- 如果文件包含优化器状态,可以通过字典键值提取。