【Pytorch】模型的保存和加载

在对神经网络模型进行训练时,定期地保存模型可以增加程序的抗风险能力。同时,通过对保存模型的加载可以很方便地复现和使用我们训练好的神经网络模型。基于此,本文记录了 Pytorch 中的模型保存和加载方法。

Pytorch的模型后缀一般为 .pt.pth。在保存模型时,我们有两种选择:

  • 保存整个神经网络模型,包括神经网络的模型结构和模型参数
  • 只保存神经网络的模型参数

不同的保存方式对应不同的加载方式。

1. 只保存神经网络的模型参数

只保留神经网络的模型参数时,调用的是 torch.save() 函数,如:

1
2
model = ...
torch.save(model.state_dict(), "model.pt")

此时,在 model.pt 文件中将保存神经包含 weight 参数和 bias 参数的字典以反映神经网络模型的状态信息。也就是说,该方式只保存了神经网络中可以进行学习的参数。

在加载参数时,需要先建立好对应结构的神经网络模型,然后再进行调用,如:

1
2
model = ...
model.load_state_dict(torch.load("model.pt"))

2. 保存整个神经网络模型

在保存整个神经网络时,调用的也是 torch.save() 函数,如:

1
2
model = ...
torch.save(model, "model.pt")

而在加载模型时,并不需要先建立好对应的神经网络模型结构,而是可以直接赋值:

1
model = torch.load("model.pt")

3. 保存训练过程的 Checkpoint 以期继续训练

若是在训练途中保存 Checkpoint,仍可以通过调用 torch.save() 函数实现需求:

1
2
3
4
5
6
7
8
9
10
model = ...
torch.save({
"episode": episode,
"current_step": current_step,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": loss,
"reward": reward,
...
}, "model.pt")

此时,在加载时需要对对应内容进行索引:

1
2
3
4
5
6
7
8
9
10
model = ...
opt = ...
check_point = torch.load("model.pt")
model.load_state_dict(check_point["model_state_dict"])
opt.load_state_dict(check_point["optimizer_state_dict"])
episode = check_point["episode"]
current_step = check_point["current_step"]
loss = check_point["loss"]
reward = check_point["reward"]
...

此外,若是想在同一个文件里保存多个模型,也可以使用该方法保存,并通过索引读取不同的参数。

4. 读取模型后需要注意的事

在使用保存的模型进行预测时,调用后还需要加上:

1
model.eval()

同时,也需要关闭梯度的计算。此时的写法为:

1
2
3
4
5
6
7
8
9
10
model = ...
model.load_state_dict(torch.load("model.pt"))
# 或 model.load("model.pt")

model.eval()

with torch.no_grad():
...
predict_data = model(input_data)
...

同样的,在使用保存的模型进行训练时,也需要加上:

1
model.train()

将模型转换为训练模式。

上述操作的必要性在于,在训练阶段和测试阶段中 Dropout 层和 Batch Normalization 层将产生不同影响。若是不加控制,则在训练阶段和测试阶段难以得到期望结果。

对于 Dropout 层而言,在前向传播时,dropout 函数会让神经元的激活值以给定的概率停止工作(神经网络神经元的值置 0)以避免神经网络对局部特征的过度依赖,进而增强模型的泛化性。而在测试的时候,我们使用的是已经训练好的模型。在使用过程中显然不需要 dropout 的功能。同时,如果 dropout 功能仍然存在,将不可避免地会影响到最终输出结果的准确性。

对于 Batch Normalization 层而言,Batch Normalization 是为了让神经网络每层的数据输入都保持在相近的范围:

其中,模型是通过计算输入数据的均值和方差来实现 Batch Normalization。而在使用模型的时候,对输入的单个数据或需要预测的一批数据计算均值和方差是没有意义的,而应该直接使用训练阶段得到的对整体样本空间估算的均值和方差进行计算。

打赏
  • 版权声明: 本博客所有文章除特别声明外,著作权归作者所有。转载请注明出处!
  • Copyrights © 2022-2024 lgc0208@foxmail.com
  • 访问人数: | 浏览次数:

请我喝杯咖啡吧~

支付宝
微信