Pytorch – 模型保存与加载以及如何在已保存的模型的基础上继续训练模型
1 模型的保存和加载
1.1 保存与加载整个模型
保存网络的所有模块,代码量少。
但是这种方法缺点是保存模型的时候,序列化的数据被绑定到了特定的类和确切的目录。 这是因为pickle不保存模型类本身,而是保存这个类的路径, 并且在加载的时候会使用。因此, 当在其他项目里使用或者重构的时候,加载模型的时候会出错。
保存模型
import torch
torch.save(net,PATH)
加载模型
model=torch.load(PATH)
1.2 保存与加载模型中的参数
保存网络中的参数,速度快,占用空间少。
保存模型参数
torch.save(net.state_dict(),PATH)
加载模型参数
#定义模型结构
model=Model().cuda()
model.load_state_dict(torch.load(PATH))
1.3 保存与加载自定义模型
保存自定义模型参数
torch.save(
{
'epoch': epochID + 1,
'state_dict': model.state_dict(),
'best_loss': lossMIN,
'optimizer': optimizer.state_dict(),
'alpha': loss.alpha,
'gamma': loss.gamma
},
checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')
比如我们在上述代码中保存了epochID,模型的state_dict,min_loss,optimizer的state_dict等。
加载自定义模型
def load_checkpoint(model, checkpoint_PATH, optimizer):
if checkpoint != None:
model_CKPT = torch.load(checkpoint_PATH)
model.load_state_dict(model_CKPT['state_dict'])
print('loading checkpoint!')
optimizer.load_state_dict(model_CKPT['optimizer'])
return model, optimizer
如果对网络进行了增删查改,那么需要过滤一些旧的参数,那么加载代码修改为
def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
if checkpoint != 'No':
print("loading checkpoint...")
model_dict = model.state_dict()
modelCheckpoint = torch.load(checkpoint)
pretrained_dict = modelCheckpoint['state_dict']
# 过滤操作
new_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict.keys()}
model_dict.update(new_dict)
# 打印出来,更新了多少的参数
print('Total : {}, update: {}'.format(len(pretrained_dict), len(new_dict)))
model.load_state_dict(model_dict)
print("loaded finished!")
# 如果不需要更新优化器那么设置为false
if loadOptimizer == True:
optimizer.load_state_dict(modelCheckpoint['optimizer'])
print('loaded! optimizer')
else:
print('not loaded optimizer')
else:
print('No checkpoint is included')
return model, optimizer
1.4 在加载的模型上继续训练
在训练模型的时候可能会因为一些问题导致程序中断,或者常常需要观察训练情况的变化来更改学习率等参数,这时候就需要加载中断前保存的模型,并在此基础上继续训练
#-*- coding:utf-8 -*-
'''本文件用于举例说明pytorch保存和加载文件的方法'''
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch.backends.cudnn as cudnn
import datetime
import argparse
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0 # dataloder线程数
test_flag = True #测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar' # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth' # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
# 构造模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(256 * 8 * 8, 1024)
self.fc2 = nn.Linear(1024, 256)
self.fc3 = nn.Linear(256, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool(F.relu(self.conv2(x)))
x = F.relu(self.conv3(x))
x = self.pool(F.relu(self.conv4(x)))
x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 模型训练
def train(model, train_loader, epoch):
model.train()
train_loss = 0
for i, data in enumerate(train_loader, 0):
x, y = data
x = x.cuda()
y = y.cuda()
optimizer.zero_grad()
y_hat = model(x)
loss = criterion(y_hat, y)
loss.backward()
optimizer.step()
train_loss += loss
loss_mean = train_loss / (i+1)
print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
# 模型测试
def test(model, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for i, data in enumerate(test_loader, 0):
x, y = data
x = x.cuda()
y = y.cuda()
optimizer.zero_grad()
y_hat = model(x)
test_loss += criterion(y_hat, y).item()
pred = y_hat.max(1, keepdim=True)[1]
correct += pred.eq(y.view_as(pred)).sum().item()
test_loss /= (i+1)
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_data), 100. * correct / len(test_data)))
def main():
# 如果test_flag=True,则加载已保存的模型
if test_flag:
# 加载保存的模型直接进行测试机验证,不进行此模块以后的步骤
checkpoint = torch.load(log_dir)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
epochs = checkpoint['epoch']
test(model, test_load)
return
for epoch in range(0, epochs):
train(model, train_load, epoch)
test(model, test_load)
# 保存模型
state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
torch.save(state, log_dir)
if __name__ == '__main__':
main()
上述代码文件是比较常规的训练模型与保存模型文件的代码,我们可以通过保存的模型参数以及修改上述代码的main函数使其具有加载离线模型文件并可继续训练的功能,修改后的代码文件如下
#-*- coding:utf-8 -*-
'''本文件用于举例说明pytorch保存和加载文件的方法'''
import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch.backends.cudnn as cudnn
import datetime
import argparse
# 参数声明
batch_size = 32
epochs = 10
WORKERS = 0 # dataloder线程数
test_flag = True #测试标志,True时加载保存好的模型进行测试
ROOT = '/home/pxt/pytorch/cifar' # MNIST数据集保存路径
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth' # 模型保存路径
# 加载MNIST数据集
transform = tv.transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)
# 构造模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
self.conv3 = nn.Conv2d(128, 256, 3, padding=1)
self.conv4 = nn.Conv2d(256, 256, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(256 * 8 * 8, 1024)
self.fc2 = nn.Linear(1024, 256)
self.fc3 = nn.Linear(256, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = self.pool(F.relu(self.conv2(x)))
x = F.relu(self.conv3(x))
x = self.pool(F.relu(self.conv4(x)))
x = x.view(-1, x.size()[1] * x.size()[2] * x.size()[3])
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
model = Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 模型训练
def train(model, train_loader, epoch):
model.train()
train_loss = 0
for i, data in enumerate(train_loader, 0):
x, y = data
x = x.cuda()
y = y.cuda()
optimizer.zero_grad()
y_hat = model(x)
loss = criterion(y_hat, y)
loss.backward()
optimizer.step()
train_loss += loss
loss_mean = train_loss / (i+1)
print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean.item()))
# 模型测试
def test(model, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for i, data in enumerate(test_loader, 0):
x, y = data
x = x.cuda()
y = y.cuda()
optimizer.zero_grad()
y_hat = model(x)
test_loss += criterion(y_hat, y).item()
pred = y_hat.max(1, keepdim=True)[1]
correct += pred.eq(y.view_as(pred)).sum().item()
test_loss /= (i+1)
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_data), 100. * correct / len(test_data)))
def main():
# 如果test_flag=True,则加载已保存的模型
if test_flag:
# 加载保存的模型直接进行测试机验证,不进行此模块以后的步骤
checkpoint = torch.load(log_dir)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
test(model, test_load)
return
# 如果有保存的模型,则加载模型,并在其基础上继续训练
if os.path.exists(log_dir):
checkpoint = torch.load(log_dir)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']
print('加载 epoch {} 成功!'.format(start_epoch))
else:
start_epoch = 0
print('无保存模型,将从头开始训练!')
for epoch in range(start_epoch+1, epochs):
train(model, train_load, epoch)
test(model, test_load)
# 保存模型
state = {'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 'epoch':epoch}
torch.save(state, log_dir)
if __name__ == '__main__':
main()
参考链接
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 模型保存与加载以及如何在已保存的模型的基础上继续训练模型
原文链接:https://www.stubbornhuang.com/2200/
发布于:2022年07月09日 8:31:03
修改于:2023年06月25日 20:59:04
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50