Pytorch – 没有使用with torch.no_grad()造成测试网络时显存爆炸的问题
1 显存爆炸的问题
最近使用以下示例代码测试自定义深度学习网络时耗光了所有显存,出现了梯度爆炸的问题。
model.eval()
for batch_idx, data in enumerate(tqdm(data_loader)):
image = data[0].to('cuda:0')
......
经过排查原因是没有加上with torch.no_grad()
语句停止梯度更新,从而导致了显存爆炸的问题,正确的示例代码如下
model.eval()
with torch.no_grad()
for batch_idx, data in enumerate(tqdm(data_loader)):
image = data[0].to('cuda:0')
......
2 model.train、model.eval和with torch.no_grad
model.train()
会将网络中的模块设置为训练模式,此时,如果神经网络中BN(batch normalization)层和Dropout层,那么这两个层将会起作用,防止网络出现过拟合的问题。
model.eval()
则会将网络设置为测试模式,此时,不会启用神经网络中的BN(batch normalization)层和Dropout层,model.eval()是保证BN层直接使用全部训练数据的均值和方差,即测试过程中要保证BN层的均值和方差不变。对于Dropout层,model.eval()是利用到了所有网络连接,即不进行随机舍弃神经元。
with torch.no_grad()
会将网络中Tensor的属性全部设置为False,并停止Autograd引擎,禁止梯度反向传播,以起到加速和节省显存的作用。它的作用是将该with语句包裹起来的部分停止梯度的更新,从而节省了GPU算力和显存,但是并不会影响dropout和BN层的行为。因此,测试的时候加上此语句也不会影响测试精度的,只是停止了梯度更新而已。在测试和验证阶段,使用with torch.no_grad()
会使得网络有更快的推理速度和内存使用,这使得我们在验证和测试网络时可以使用更大的batch size。
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 没有使用with torch.no_grad()造成测试网络时显存爆炸的问题
原文链接:https://www.stubbornhuang.com/2322/
发布于:2022年08月23日 10:29:14
修改于:2023年06月21日 18:13:30
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50