跳转至

Pytorch 4 :Save & Load & Pretrain

导言

  • 保存与加载模型:学习如何保存训练好的模型,并在需要时加载模型进行推理或继续训练。
  • 迁移学习:学习如何使用预训练模型进行迁移学习,微调模型以适应新的任务。
  • 常用预训练模型:介绍PyTorch中常用的预训练模型,如ResNet、VGG等。

保存与读取

在 PyTorch 训练模型时,我们通常需要保存模型,以便后续继续训练或进行推理(Inference)。PyTorch 提供了两种常见的保存方式:

  1. 仅保存模型参数(推荐方式)
  2. 保存完整模型(包括结构和参数)

下面是详细的方法和代码示例。


1. 仅保存模型参数(推荐)

这种方式只保存模型的 state_dict(即模型的参数),但不包含模型结构。优点是更灵活,加载时可以创建相同结构的模型,再加载参数。

1.1 保存模型参数

import torch

# 假设 model 是你的神经网络
torch.save(model.state_dict(), "model.pth")
这样会把 model.pth 作为文件保存,其中仅包含模型的参数(权重和偏置)


1.2 加载模型参数

加载时需要先定义模型的结构,然后再加载 state_dict

import torch
import torch.nn as nn

# 重新定义模型结构(必须和保存时的模型结构一致)
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 20)

    def forward(self, x):
        return self.fc(x)

# 创建模型实例
model = MyModel()

# 加载保存的参数
model.load_state_dict(torch.load("model.pth"))

# 切换为评估模式(如果是推理)
model.eval()

注意: - 需要 手动创建模型结构,否则无法正确加载参数。 - model.eval() 让模型进入推理模式(影响 BatchNormDropout)。推理时必须调用 model.eval(),否则 Dropout 仍然生效,BN 计算方式也不同,可能导致预测结果不稳定!


2. 保存完整模型(包括结构+参数)

如果希望保存整个模型(包括结构和参数),可以直接保存 model

torch.save(model, "full_model.pth")
这样,full_model.pth 会包含 模型结构 + 训练参数

2.1 加载完整模型

model = torch.load("full_model.pth")
model.eval()  # 切换到推理模式

注意: - 这种方法适用于简单的模型,但依赖 Python 代码,如果代码环境变化(如不同版本 PyTorch),可能无法加载。 - 推荐保存 state_dict,而不是整个模型,因为 state_dict 更通用、兼容性更好。


3. 训练过程中定期保存模型

在训练时,我们通常希望 定期保存模型,例如每 10 个 epoch 保存一次:

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer)  # 训练代码

    if epoch % 10 == 0:  # 每 10 轮保存一次
        torch.save(model.state_dict(), f"model_epoch_{epoch}.pth")

这样可以在训练中断后继续训练,或者选择最优的模型进行推理。


4. 保存和加载模型 + Optimizer(继续训练)

如果要继续训练,需要同时保存 模型参数 + 优化器状态

4.1 保存模型和优化器

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss_value
}, "checkpoint.pth")

4.2 加载并继续训练

checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']  # 继续训练的起点
loss_value = checkpoint['loss']

model.train()  # 切换回训练模式

总结

方法 代码 适用场景 优缺点
仅保存参数(推荐) torch.save(model.state_dict(), "model.pth") 适用于大多数情况(推理、继续训练) 需要手动定义模型结构,但兼容性好
保存完整模型 torch.save(model, "full_model.pth") 适用于小型项目或临时存储 依赖 Python 代码,可能导致兼容性问题
保存模型+优化器 torch.save({...}, "checkpoint.pth") 适用于断点续训 可以继续训练,但文件较大

如果你的目标是 部署推理,推荐 只保存 state_dict,这样加载更灵活!😃

判断保存格式

在 PyTorch 中,.pth 文件通常是通过 torch.save() 保存的模型文件。它可能包含以下内容:

  1. 仅模型参数:通过 torch.save(model.state_dict(), 'model.pth') 保存。
  2. 模型结构和参数:通过 torch.save(model, 'model.pth') 保存。
  3. 模型、优化器和其他信息:通过 torch.save({'model': model.state_dict(), 'optimizer': optimizer.state_dict(), ...}, 'model.pth') 保存。

以下是判断和打印结构的方法:


1. 检查文件内容

首先加载文件,查看其内容:

import torch

# 加载文件
checkpoint = torch.load('model.pth')

# 打印文件内容
print(checkpoint.keys())
  • 如果输出是 dict_keys(['state_dict']) 或类似内容layer1.weightlayer1.bias,说明只保存了模型参数。
  • 如果输出包含 dict_keys(['model', 'optimizer', 'epoch', ...]),说明保存了模型、优化器等信息。
  • 如果输出是模型结构(如 OrderedDicttorch.nn.Module),说明保存了完整的模型结构和参数。

2. 打印模型结构

如果保存的是完整模型(结构和参数)

import torch

# 加载模型
model = torch.load('model.pth')

# 打印模型结构
print(model)

如果保存的是模型参数(state_dict

你需要先定义模型结构,然后加载参数:

import torch
import torch.nn as nn

# 假设你的模型类为 MyModel
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

# 实例化模型
model = MyModel()

# 加载参数
model.load_state_dict(torch.load('model.pth'))

# 打印模型结构
print(model)

网络结构发生变化(例如 input_channel 的值变了),加载参数时会失败。

RuntimeError: Error(s) in loading state_dict: 
size mismatch for layer1.weight: copying a param with shape torch.Size([64, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 256, 3, 3]).

如果只有部分层的参数形状不匹配,可以手动加载匹配的参数,忽略不匹配的部分:

checkpoint = torch.load('model.pth')
model_dict = model.state_dict()

# 过滤掉不匹配的参数
pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict and v.shape == model_dict[k].shape}

# 更新当前模型的参数
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

可能不止一个state_dict权重, 用户可能会保存ema_state_dict之类的其他权重,选择需要的。

ema_state_dict 表示的是 Exponential Moving Average (EMA) 模型参数的字典。EMA 是一种用于模型权重平滑的技术,它通过计算模型参数的移动平均值来提高模型的泛化能力和稳定性。

!!! tip register_buffer"

`register_buffer` 是 PyTorch 中 `nn.Module` 类的一个方法,用于注册缓冲区(buffer)。缓冲区是与模型参数类似但又不同的变量,它们的主要特点如下:

- **持久化存储**:缓冲区会被保存在模型的状态字典中,因此在**保存和加载模型**时也会一并保存和加载。
- **不参与梯度计算**:缓冲区不会被自动求导机制跟踪,即不会计算梯度,也不会更新通过优化器。
- **设备迁移**:缓冲区会随着模型一起迁移到不同的设备(如 CPU 或 GPU)。


```python
self.register_buffer('h',
    torch.tensor([[[1, 1], [1, 1]], [[1, 1], [1, 1]]]).view(1, 1, 2, 2, 2) * 0.3536
)
```

这段代码的作用是:

- 创建一个名为 `h` 的缓冲区,并将其值初始化为一个特定的张量。这个张量表示小波变换中的低通滤波器卷积核。
- 这个缓冲区会在模型保存和加载时被保留,并且不会参与反向传播的梯度计算。

总结来说,`register_buffer` 适用于那些需要持久化存储、但不需要进行梯度更新的变量,例如统计量、预定义的权重矩阵等。

3. 判断是否包含优化器

如果文件是字典形式,检查是否包含优化器的 state_dict

if 'optimizer' in checkpoint:
    print("文件包含优化器状态")
    optimizer_state = checkpoint['optimizer']
else:
    print("文件不包含优化器状态")

4. 总结

  • 如果文件是 state_dict,你需要先定义模型结构,再加载参数。
  • 如果文件是完整模型,可以直接加载并使用。
  • 如果文件包含优化器状态,可以通过字典键值提取。