Pytorch – 梯度累积/梯度累加trick,在显存有限的情况下使用更大batch_size训练模型
1 batch size对模型训练的影响
小的batch size引入的数据集的数据量较小,随机性越大,在部分情况下模型难以收敛,影响模型训练效率。
而在合理的范围内,越大的batch size本质上是对训练数据更优的一种选择,能够是梯度下降的方向更加准确,震荡越小,有利于收敛的稳定性。
但是如果batch size过大,超出了一个合理的范围,会限制模型的探索能力,出现局部最优的情况。
所以,在模型训练的过程中,bacth size太小和太大都不是一个好的选择。
2 为什么要使用梯度累积(gradient accumulation)的方案?
深度学习发展到今天,数据已经从图片,文本逐渐发展到了音频、视频这种更复杂,更需要高维度表示的数据。
试想一下,如果我们需要基于一个超百万个的视频数据集进行模型训练,需要使用Resnet50或者Resnet101作为backbone对每一个视频的视频帧提取特征并进行组合,假设每个视频有250帧,那么我们使用batch size=2进行训练,就等于我们需要对500张图片都进行Resnet50或者Resnet101进行计算,而如果我们只有一张3090(24G),那么在这种情况下24G的显存显然是不够用的。
在显卡的显存不够多,不足以支撑大的batch_size数据训练的情况下,而我们又想使用大的batch_size进行模型的训练,那么这个时候我们就可以使用梯度累计的方式进行优化,以防止显存爆炸。
3 在Pytorch中使用梯度累积
在Pytorch中反向传播梯度是不清零的,所以要实现梯度累积是比较简单的。
3.1 常用的训练模式
在Pytorch中,训练一个epoch训练的常用代码如下
for batch_idx, (input_id, label) in enumerate(train_loader):
# 1. 模型输出
pred = model(input_id)
loss = criterion(pred, label)
# 2. 反向传播
optimizer.zero_grad() # 梯度清空
loss.backward() # 反向传播,计算梯度
optimizer.step() # 根据梯度,更新网络参数
总结步骤如下:
- 计算loss,获取batch输入,计算model输出,通过损失函数计算loss
- optimizer.zero_grad() 清空之前的梯度
- loss.backward()反向传播,计算当前梯度
- optimizer.step()根据梯度更新网络参数
简单来说,就是进来一个batch的数据,计算一次梯度,更新一次网络。
3.2 梯度累积
我们将上述代码修改为梯度累积的方式
for batch_idx ,(input_id, label) in enumerate(train_loader):
# 1. 模型输出
pred = model(input_id)
loss = criterion(pred, label)
# 2.1 损失规范
loss = loss / accumulation_steps
# 2.2 反向传播,计算梯度
loss.backward()
if (batch_idx+1) % accumulation_steps == 0:
optimizer.step() # 更新参数
optimizer.zero_grad() # 梯度清空,为下一次反向传播做准备
总结步骤如下:
- 计算loss,获取batch输入,计算model输出,通过损失函数计算loss
- loss.backward()反向传播,计算当前梯度
- 多次循环步骤 1-2,不清空梯度,使梯度累加在已有梯度上;
- 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;
总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。
在合理的范围内,batch size越大训练效果越好,梯度累积变相实现了batch_size的扩大,如果accumulation_steps
为8则batch size变相扩大了8倍,这是解决显存受限的很好的trick方式,不过在使用梯度累积的时候,学习率也要适当的放大。
参考链接
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 梯度累积/梯度累加trick,在显存有限的情况下使用更大batch_size训练模型
原文链接:https://www.stubbornhuang.com/2444/
发布于:2022年12月09日 14:13:29
修改于:2023年06月21日 17:44:17
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50