Pytorch – 模型断点续训,optimizer.step()报错:RuntimeError Expected all tensors to be on the same device, but found cuda:0
1 模型断点续训,optimizer.step()报错:RuntimeError Expected all tensors to be on the same device, but found cuda:0
Pytroch在实现断点续训功能时,在保存模型文件时,需要同时保存model、optimizer、lr_scheduler的state_dict,比如
torch.save({
'epoch': epoch,
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'scheduler_state_dict': self.lr_scheduler.state_dict(),
}, model_save_path)
然后在加载模型时,除了加载模型的权重之外,还需要同时加载optimizer和lr_scheduler的权重,比如
model_weights = modified_weights(check_point_state_dict['model_state_dict'])
optimizer.load_state_dict(check_point_state_dict["optimizer_state_dict"])
lr_scheduler.load_state_dict(check_point_state_dict["scheduler_state_dict"])
这个时候比较容易犯的错误是,optimizer默认是在cpu上加载权重的,而我们之后继续训练模型时都是在GPU上进行了,所以如果optimizer没有任何修改,则会出在optimizer.step()
执行时出现
RuntimeError: Expected all tensors to be on the same device, but found cuda:0
其实际上就是optimizer的权重没有在GPU上,所以解决方法就是将optimizer的权重转移到GPU上,示例代码如下
optimizer.load_state_dict(check_point_state_dict["optimizer_state_dict"])
for state in optimizer.state.values():
for k, v in state.items():
if isinstance(v, torch.Tensor):
state[k] = v.to(self.output_device)
其中self.output_device
就是项目中的GPU索引号。
修改完成之后,错误解决。
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 模型断点续训,optimizer.step()报错:RuntimeError Expected all tensors to be on the same device, but found cuda:0
原文链接:https://www.stubbornhuang.com/2603/
发布于:2023年05月08日 11:03:42
修改于:2023年05月08日 11:04:20
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50