PyTorch: Adjusting the Learning Rate Manually and with torch.optim.lr_scheduler
1 Ways to Adjust the Learning Rate in PyTorch
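In PyTorch, the learning rate can be adjusted during training in two ways:
- using the learning-rate schedulers provided by torch.optim.lr_scheduler;
- adjusting it manually, i.e. modifying the lr entry of each parameter group in the optimizer.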
1.1 Adjusting the Learning Rate Manually
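The code for manual adjustment is shown below. In the function adjust_learn_rata, the epoch index drives a linear warm-up phase over the first 20 epochs. After warm-up, the learning rate is multiplied by 0.9 (i.e. reduced to 90% of its previous value) every 3 epochs.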
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

def adjust_learn_rata(epoch, optimizer, base_lr):
    warm_up_epochs = 20
    if epoch < warm_up_epochs:
        # Linear warm-up: base_lr/20, 2*base_lr/20, ..., base_lr
        lr = base_lr * (epoch + 1) / warm_up_epochs
    else:
        # After warm-up: keep the current lr, and multiply it by 0.9 on every 3rd epoch
        lr = optimizer.param_groups[0]['lr']
        if epoch % 3 == 0:
            lr = lr * 0.9
    # Write the new lr into every parameter group of the optimizer
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyNet().to(device)
    base_lr = 0.0001
    optimizer = torch.optim.Adam(model.parameters(), lr=base_lr, betas=(0.9, 0.99))
    lr_list = []
    epochs = 100
    for epoch in range(epochs):
        # Set this epoch's lr BEFORE the update, so the warm-up lr is used from epoch 0
        adjust_learn_rata(epoch, optimizer, base_lr)
        optimizer.zero_grad()
        # A real training loop would compute the loss and call loss.backward() here
        optimizer.step()
        lr_list.append(optimizer.param_groups[0]['lr'])
    plt.title("learn rate demo")
    plt.xlabel("epoch")
    plt.ylabel("learning rate")
    plt.plot(range(epochs), lr_list, color='r')
    plt.show()
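As a cross-check, the schedule implemented by adjust_learn_rata can also be written without any optimizer state, as a closed-form factor relative to base_lr. The sketch below is plain Python (no PyTorch). The helpers warmup_step_factor and simulate_manual are hypothetical names: simulate_manual replays the stateful logic of adjust_learn_rata, and warmup_step_factor gives the same values from the epoch index alone.

```python
import math


def warmup_step_factor(epoch, warm_up_epochs=20, decay_every=3, decay=0.9):
    """Factor (relative to base_lr) that the manual schedule uses at `epoch`."""
    if epoch < warm_up_epochs:
        # Linear warm-up: 1/20, 2/20, ..., 20/20
        return (epoch + 1) / warm_up_epochs
    # Number of epochs in [warm_up_epochs, epoch] divisible by decay_every;
    # each of them applied one extra factor of `decay`
    n_decays = epoch // decay_every - (warm_up_epochs - 1) // decay_every
    return decay ** n_decays


def simulate_manual(epochs, base_lr, warm_up_epochs=20, decay_every=3, decay=0.9):
    """Stateful replay of adjust_learn_rata without PyTorch: one lr per epoch."""
    lrs, lr = [], base_lr
    for epoch in range(epochs):
        if epoch < warm_up_epochs:
            lr = base_lr * (epoch + 1) / warm_up_epochs
        elif epoch % decay_every == 0:
            lr = lr * decay
        lrs.append(lr)
    return lrs


if __name__ == "__main__":
    base_lr = 0.0001
    stateful = simulate_manual(100, base_lr)
    closed = [base_lr * warmup_step_factor(e) for e in range(100)]
    # Both formulations agree up to floating-point rounding
    print(all(math.isclose(a, b, rel_tol=1e-12) for a, b in zip(stateful, closed)))
```

Because warmup_step_factor depends only on the epoch, it could in principle be passed as lr_lambda to torch.optim.lr_scheduler.LambdaLR (Section 2.1), with the optimizer created with lr=base_lr, instead of editing param_groups by hand.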
The learning-rate curve is shown below:
1.2 Adjusting the Learning Rate with torch.optim.lr_scheduler
torch.optim.lr_scheduler provides the following methods for adjusting the learning rate:
- torch.optim.lr_scheduler.LambdaLR
- torch.optim.lr_scheduler.MultiplicativeLR
- torch.optim.lr_scheduler.StepLR
- torch.optim.lr_scheduler.MultiStepLR
- torch.optim.lr_scheduler.ExponentialLR
- torch.optim.lr_scheduler.CosineAnnealingLR
- torch.optim.lr_scheduler.ReduceLROnPlateau
- torch.optim.lr_scheduler.CyclicLR
- torch.optim.lr_scheduler.OneCycleLR
- torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
Below we introduce each of these methods and show how the learning rate changes under each one.
2 torch.optim.lr_scheduler
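torch.optim.lr_scheduler provides several ways to adjust the learning rate based on the number of training epochs, while torch.optim.lr_scheduler.ReduceLROnPlateau lowers the learning rate dynamically based on a monitored metric (for example, the validation loss).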
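The difference between epoch-driven schedules and metric-driven reduction can be illustrated with a simplified, framework-free sketch of the plateau idea. PlateauSketch is a hypothetical class, not a PyTorch API: it assumes mode='min' (a lower metric is better) and ignores ReduceLROnPlateau's threshold and cooldown options. factor and patience play the same role as in PyTorch, where their defaults are 0.1 and 10.

```python
class PlateauSketch:
    """Simplified plateau-based decay: mode='min', no threshold or cooldown."""

    def __init__(self, lr, factor=0.1, patience=10, min_lr=0.0):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_lr = min_lr
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric):
        if metric < self.best:
            # Improvement: remember the new best value and reset the counter
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:
            # Plateau detected: decay the lr (not below min_lr) and start counting again
            self.lr = max(self.lr * self.factor, self.min_lr)
            self.num_bad_epochs = 0
        return self.lr


if __name__ == "__main__":
    sched = PlateauSketch(lr=0.1, factor=0.5, patience=2)
    for loss in [1.0, 0.8, 0.8, 0.8, 0.8, 0.7]:
        print(sched.step(loss))
```

With patience=2, the lr is halved only after the loss fails to improve for three consecutive epochs, so the loop prints 0.1 four times and then 0.05 twice. An epoch-based scheduler, by contrast, would decay at fixed epochs regardless of the loss values.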
2.1 torch.optim.lr_scheduler.LambdaLR
1. Class signature
torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)
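2. Update rule
$$lr_{new} = lr_{initial} \times \lambda, \qquad \lambda = \text{lr\_lambda}(epoch)$$
where $lr_{new}$ is the new learning rate, $lr_{initial}$ is the initial learning rate, and $\lambda$ is the factor computed by the lr_lambda argument from the epoch.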
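3. Parameters
- optimizer: Optimizer object. The optimizer whose learning rate is adjusted;
- lr_lambda: function or list of functions. A function that computes the multiplicative factor $\lambda$ from the epoch, or a list of such functions, one for each parameter group;
- last_epoch: int, default -1. Index of the last epoch. When resuming training that was interrupted after many epochs, set this to the epoch of the loaded checkpoint. The default -1 means training starts from scratch, i.e. from epoch 0;
- verbose: bool, default False. If True, a message is printed to the console at every update;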
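The update rule and the per-group behaviour of lr_lambda can be reproduced in plain Python. This is not PyTorch code: lambda_lr is a hypothetical helper that takes the initial learning rate of each parameter group plus either one function (shared by all groups) or one function per group, and returns the learning rates for a given epoch.

```python
def lambda_lr(initial_lrs, lr_lambdas, epoch):
    """new_lr = initial_lr * lambda(epoch), evaluated for every parameter group."""
    if callable(lr_lambdas):
        # A single function is shared by all parameter groups
        lr_lambdas = [lr_lambdas] * len(initial_lrs)
    if len(lr_lambdas) != len(initial_lrs):
        raise ValueError("expected one lr_lambda per parameter group")
    return [lr0 * fn(epoch) for lr0, fn in zip(initial_lrs, lr_lambdas)]


if __name__ == "__main__":
    # Two illustrative parameter groups: a decaying one and a constant one
    lambdas = [lambda e: 0.95 ** e, lambda e: 1.0]
    for epoch in range(3):
        print(epoch, lambda_lr([1e-4, 1e-3], lambdas, epoch))
```

Because the factor always multiplies the initial learning rate, $\lambda$ must itself encode the whole schedule (here $0.95^{epoch}$); the scheduler does not remember earlier factors.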
4. Example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

def adjust_learnrate(epoch):
    # Factor applied to the INITIAL lr: lr = 0.0001 * 0.95 ** epoch
    lamda = 0.95 ** epoch
    print('lamda = {}'.format(lamda))
    return lamda

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=adjust_learnrate)
    lr_list = []
    epochs = 100
    for epoch in range(epochs):
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        print('Epoch:{} , learn rate = {}'.format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
        optimizer.zero_grad()
        # A real loop would compute the loss and call loss.backward() here;
        # optimizer.step() must be called before scheduler.step()
        optimizer.step()
        scheduler.step()
    plt.title("learn rate demo")
    plt.xlabel("epoch")
    plt.ylabel("lr")
    plt.plot(range(epochs), lr_list, color='r')
    plt.show()
Output
lamda = 1.0
Epoch:0 , learn rate = 0.0001
lamda = 0.95
Epoch:1 , learn rate = 9.5e-05
lamda = 0.9025
Epoch:2 , learn rate = 9.025e-05
lamda = 0.8573749999999999
Epoch:3 , learn rate = 8.573749999999999e-05
lamda = 0.8145062499999999
Epoch:4 , learn rate = 8.1450625e-05
lamda = 0.7737809374999998
Epoch:5 , learn rate = 7.737809374999998e-05
lamda = 0.7350918906249998
Epoch:6 , learn rate = 7.350918906249998e-05
lamda = 0.6983372960937497
Epoch:7 , learn rate = 6.983372960937497e-05
lamda = 0.6634204312890623
Epoch:8 , learn rate = 6.634204312890623e-05
lamda = 0.6302494097246091
Epoch:9 , learn rate = 6.30249409724609e-05
lamda = 0.5987369392383787
Epoch:10 , learn rate = 5.987369392383787e-05
lamda = 0.5688000922764597
Epoch:11 , learn rate = 5.688000922764597e-05
lamda = 0.5403600876626367
Epoch:12 , learn rate = 5.403600876626367e-05
lamda = 0.5133420832795048
Epoch:13 , learn rate = 5.1334208327950485e-05
lamda = 0.48767497911552954
Epoch:14 , learn rate = 4.876749791155295e-05
lamda = 0.46329123015975304
Epoch:15 , learn rate = 4.6329123015975305e-05
lamda = 0.44012666865176536
Epoch:16 , learn rate = 4.4012666865176535e-05
lamda = 0.4181203352191771
Epoch:17 , learn rate = 4.181203352191771e-05
lamda = 0.3972143184582182
Epoch:18 , learn rate = 3.972143184582182e-05
lamda = 0.37735360253530725
Epoch:19 , learn rate = 3.7735360253530726e-05
lamda = 0.3584859224085419
Epoch:20 , learn rate = 3.584859224085419e-05
lamda = 0.3405616262881148
Epoch:21 , learn rate = 3.405616262881148e-05
lamda = 0.323533544973709
Epoch:22 , learn rate = 3.2353354497370904e-05
lamda = 0.3073568677250236
Epoch:23 , learn rate = 3.073568677250236e-05
lamda = 0.2919890243387724
Epoch:24 , learn rate = 2.919890243387724e-05
lamda = 0.27738957312183377
Epoch:25 , learn rate = 2.7738957312183377e-05
lamda = 0.26352009446574204
Epoch:26 , learn rate = 2.6352009446574204e-05
lamda = 0.2503440897424549
Epoch:27 , learn rate = 2.5034408974245492e-05
lamda = 0.23782688525533216
Epoch:28 , learn rate = 2.3782688525533216e-05
lamda = 0.22593554099256555
Epoch:29 , learn rate = 2.2593554099256555e-05
lamda = 0.21463876394293727
Epoch:30 , learn rate = 2.146387639429373e-05
lamda = 0.2039068257457904
Epoch:31 , learn rate = 2.039068257457904e-05
lamda = 0.19371148445850087
Epoch:32 , learn rate = 1.9371148445850086e-05
lamda = 0.18402591023557582
Epoch:33 , learn rate = 1.8402591023557583e-05
lamda = 0.174824614723797
Epoch:34 , learn rate = 1.74824614723797e-05
lamda = 0.16608338398760716
Epoch:35 , learn rate = 1.6608338398760718e-05
lamda = 0.1577792147882268
Epoch:36 , learn rate = 1.577792147882268e-05
lamda = 0.14989025404881545
Epoch:37 , learn rate = 1.4989025404881546e-05
lamda = 0.14239574134637467
Epoch:38 , learn rate = 1.4239574134637468e-05
lamda = 0.13527595427905592
Epoch:39 , learn rate = 1.3527595427905592e-05
lamda = 0.12851215656510312
Epoch:40 , learn rate = 1.2851215656510312e-05
lamda = 0.12208654873684796
Epoch:41 , learn rate = 1.2208654873684796e-05
lamda = 0.11598222130000556
Epoch:42 , learn rate = 1.1598222130000557e-05
lamda = 0.11018311023500528
Epoch:43 , learn rate = 1.1018311023500529e-05
lamda = 0.10467395472325501
Epoch:44 , learn rate = 1.04673954723255e-05
lamda = 0.09944025698709225
Epoch:45 , learn rate = 9.944025698709225e-06
lamda = 0.09446824413773763
Epoch:46 , learn rate = 9.446824413773763e-06
lamda = 0.08974483193085075
Epoch:47 , learn rate = 8.974483193085076e-06
lamda = 0.0852575903343082
Epoch:48 , learn rate = 8.52575903343082e-06
lamda = 0.0809947108175928
Epoch:49 , learn rate = 8.09947108175928e-06
lamda = 0.07694497527671315
Epoch:50 , learn rate = 7.694497527671315e-06
lamda = 0.07309772651287749
Epoch:51 , learn rate = 7.30977265128775e-06
lamda = 0.06944284018723361
Epoch:52 , learn rate = 6.944284018723361e-06
lamda = 0.06597069817787193
Epoch:53 , learn rate = 6.5970698177871935e-06
lamda = 0.06267216326897833
Epoch:54 , learn rate = 6.267216326897833e-06
lamda = 0.05953855510552941
Epoch:55 , learn rate = 5.953855510552941e-06
lamda = 0.056561627350252934
Epoch:56 , learn rate = 5.6561627350252934e-06
lamda = 0.053733545982740286
Epoch:57 , learn rate = 5.373354598274029e-06
lamda = 0.051046868683603266
Epoch:58 , learn rate = 5.104686868360327e-06
lamda = 0.048494525249423104
Epoch:59 , learn rate = 4.849452524942311e-06
lamda = 0.046069798986951946
Epoch:60 , learn rate = 4.6069798986951945e-06
lamda = 0.043766309037604346
Epoch:61 , learn rate = 4.3766309037604346e-06
lamda = 0.04157799358572413
Epoch:62 , learn rate = 4.157799358572413e-06
lamda = 0.03949909390643792
Epoch:63 , learn rate = 3.949909390643792e-06
lamda = 0.03752413921111602
Epoch:64 , learn rate = 3.752413921111602e-06
lamda = 0.03564793225056022
Epoch:65 , learn rate = 3.564793225056022e-06
lamda = 0.033865535638032206
Epoch:66 , learn rate = 3.3865535638032207e-06
lamda = 0.03217225885613059
Epoch:67 , learn rate = 3.2172258856130592e-06
lamda = 0.030563645913324063
Epoch:68 , learn rate = 3.0563645913324063e-06
lamda = 0.02903546361765786
Epoch:69 , learn rate = 2.903546361765786e-06
lamda = 0.027583690436774964
Epoch:70 , learn rate = 2.7583690436774965e-06
lamda = 0.026204505914936217
Epoch:71 , learn rate = 2.6204505914936218e-06
lamda = 0.024894280619189402
Epoch:72 , learn rate = 2.4894280619189404e-06
lamda = 0.023649566588229934
Epoch:73 , learn rate = 2.3649566588229936e-06
lamda = 0.022467088258818435
Epoch:74 , learn rate = 2.2467088258818436e-06
lamda = 0.02134373384587751
Epoch:75 , learn rate = 2.1343733845877514e-06
lamda = 0.020276547153583634
Epoch:76 , learn rate = 2.0276547153583634e-06
lamda = 0.01926271979590445
Epoch:77 , learn rate = 1.9262719795904453e-06
lamda = 0.01829958380610923
Epoch:78 , learn rate = 1.829958380610923e-06
lamda = 0.017384604615803767
Epoch:79 , learn rate = 1.7384604615803768e-06
lamda = 0.01651537438501358
Epoch:80 , learn rate = 1.651537438501358e-06
lamda = 0.0156896056657629
Epoch:81 , learn rate = 1.56896056657629e-06
lamda = 0.014905125382474753
Epoch:82 , learn rate = 1.4905125382474754e-06
lamda = 0.014159869113351015
Epoch:83 , learn rate = 1.4159869113351016e-06
lamda = 0.013451875657683464
Epoch:84 , learn rate = 1.3451875657683465e-06
lamda = 0.012779281874799289
Epoch:85 , learn rate = 1.2779281874799288e-06
lamda = 0.012140317781059324
Epoch:86 , learn rate = 1.2140317781059325e-06
lamda = 0.011533301892006357
Epoch:87 , learn rate = 1.1533301892006358e-06
lamda = 0.01095663679740604
Epoch:88 , learn rate = 1.095663679740604e-06
lamda = 0.010408804957535737
Epoch:89 , learn rate = 1.0408804957535738e-06
lamda = 0.00988836470965895
Epoch:90 , learn rate = 9.88836470965895e-07
lamda = 0.009393946474176
Epoch:91 , learn rate = 9.393946474176001e-07
lamda = 0.008924249150467202
Epoch:92 , learn rate = 8.924249150467202e-07
lamda = 0.00847803669294384
Epoch:93 , learn rate = 8.478036692943841e-07
lamda = 0.008054134858296648
Epoch:94 , learn rate = 8.054134858296649e-07
lamda = 0.007651428115381815
Epoch:95 , learn rate = 7.651428115381816e-07
lamda = 0.007268856709612724
Epoch:96 , learn rate = 7.268856709612725e-07
lamda = 0.006905413874132088
Epoch:97 , learn rate = 6.905413874132088e-07
lamda = 0.006560143180425483
Epoch:98 , learn rate = 6.560143180425483e-07
lamda = 0.006232136021404208
Epoch:99 , learn rate = 6.232136021404208e-07
lamda = 0.0059205292203339975
The learning-rate curve is shown below:
2.2 torch.optim.lr_scheduler.MultiplicativeLR
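The learning rate of each parameter group is multiplied by the factor returned by the specified lambda function. The difference from torch.optim.lr_scheduler.LambdaLR in Section 2.1 is that LambdaLR multiplies the factor by the initial learning rate, whereas MultiplicativeLR multiplies it by the current learning rate of the parameter group.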
1. Class signature
torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)
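2. Update rule
$$lr_{new} = lr_{current} \times \lambda, \qquad \lambda = \text{lr\_lambda}(epoch)$$
where $lr_{new}$ is the new learning rate, $lr_{current}$ is the current learning rate, and $\lambda$ is the factor computed by the lr_lambda argument from the epoch.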
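3. Parameters
- optimizer: Optimizer object. The optimizer whose learning rate is adjusted;
- lr_lambda: function or list of functions. A function that computes the multiplicative factor $\lambda$ from the epoch, or a list of such functions, one for each parameter group;
- last_epoch: int, default -1. Index of the last epoch. When resuming training that was interrupted after many epochs, set this to the epoch of the loaded checkpoint. The default -1 means training starts from scratch, i.e. from epoch 0;
- verbose: bool, default False. If True, a message is printed to the console at every update;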
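A plain-Python sketch of the MultiplicativeLR rule helps explain why the example that follows prints almost exactly the same values as the LambdaLR example in Section 2.1. multiplicative_lr is a hypothetical helper, not a PyTorch function; it assumes, as in PyTorch's implementation, that the step into epoch k applies the factor lr_lambda(k), starting from epoch 1.

```python
import math


def multiplicative_lr(initial_lr, lr_lambda, epochs):
    """Learning rate at epochs 0..epochs-1 when each step multiplies the CURRENT lr."""
    lrs = [initial_lr]
    lr = initial_lr
    for epoch in range(1, epochs):
        # Moving into `epoch` multiplies the current lr by lr_lambda(epoch)
        lr = lr * lr_lambda(epoch)
        lrs.append(lr)
    return lrs


if __name__ == "__main__":
    const = multiplicative_lr(1e-4, lambda e: 0.95, 100)
    # A constant factor 0.95 reproduces LambdaLR with lr_lambda = 0.95 ** epoch
    print(all(math.isclose(lr, 1e-4 * 0.95 ** e) for e, lr in enumerate(const)))
    # Reusing 0.95 ** epoch here instead compounds the factors
    print(multiplicative_lr(1e-4, lambda e: 0.95 ** e, 5))
```

The second call shows a pitfall when switching between the two schedulers: with MultiplicativeLR the factors compound, so $\lambda = 0.95^{epoch}$ gives $lr_e = lr_0 \cdot 0.95^{e(e+1)/2}$, which decays much faster than the LambdaLR curve.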
4. Example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

def adjust_learnrate(epoch):
    # Constant factor applied to the CURRENT lr at every step
    lamda = 0.95
    return lamda

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))
    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=adjust_learnrate)
    lr_list = []
    epochs = 100
    for epoch in range(epochs):
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        print('Epoch:{} , learn rate = {}'.format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
        optimizer.zero_grad()
        # A real loop would compute the loss and call loss.backward() here
        optimizer.step()
        scheduler.step()
    plt.title("learn rate demo")
    plt.xlabel("epoch")
    plt.ylabel("lr")
    plt.plot(range(epochs), lr_list, color='r')
    plt.show()
Output
Epoch:0 , learn rate = 0.0001
Epoch:1 , learn rate = 9.5e-05
Epoch:2 , learn rate = 9.025e-05
Epoch:3 , learn rate = 8.573749999999999e-05
Epoch:4 , learn rate = 8.145062499999998e-05
Epoch:5 , learn rate = 7.737809374999998e-05
Epoch:6 , learn rate = 7.350918906249998e-05
Epoch:7 , learn rate = 6.983372960937497e-05
Epoch:8 , learn rate = 6.634204312890622e-05
Epoch:9 , learn rate = 6.30249409724609e-05
Epoch:10 , learn rate = 5.987369392383786e-05
Epoch:11 , learn rate = 5.688000922764596e-05
Epoch:12 , learn rate = 5.4036008766263664e-05
Epoch:13 , learn rate = 5.133420832795048e-05
Epoch:14 , learn rate = 4.876749791155295e-05
Epoch:15 , learn rate = 4.6329123015975305e-05
Epoch:16 , learn rate = 4.4012666865176535e-05
Epoch:17 , learn rate = 4.181203352191771e-05
Epoch:18 , learn rate = 3.972143184582182e-05
Epoch:19 , learn rate = 3.7735360253530726e-05
Epoch:20 , learn rate = 3.584859224085419e-05
Epoch:21 , learn rate = 3.405616262881148e-05
Epoch:22 , learn rate = 3.2353354497370904e-05
Epoch:23 , learn rate = 3.0735686772502355e-05
Epoch:24 , learn rate = 2.9198902433877236e-05
Epoch:25 , learn rate = 2.7738957312183373e-05
Epoch:26 , learn rate = 2.6352009446574204e-05
Epoch:27 , learn rate = 2.5034408974245492e-05
Epoch:28 , learn rate = 2.3782688525533216e-05
Epoch:29 , learn rate = 2.2593554099256555e-05
Epoch:30 , learn rate = 2.1463876394293726e-05
Epoch:31 , learn rate = 2.039068257457904e-05
Epoch:32 , learn rate = 1.9371148445850086e-05
Epoch:33 , learn rate = 1.840259102355758e-05
Epoch:34 , learn rate = 1.74824614723797e-05
Epoch:35 , learn rate = 1.6608338398760715e-05
Epoch:36 , learn rate = 1.5777921478822678e-05
Epoch:37 , learn rate = 1.4989025404881544e-05
Epoch:38 , learn rate = 1.4239574134637466e-05
Epoch:39 , learn rate = 1.3527595427905592e-05
Epoch:40 , learn rate = 1.2851215656510312e-05
Epoch:41 , learn rate = 1.2208654873684796e-05
Epoch:42 , learn rate = 1.1598222130000555e-05
Epoch:43 , learn rate = 1.1018311023500527e-05
Epoch:44 , learn rate = 1.04673954723255e-05
Epoch:45 , learn rate = 9.944025698709225e-06
Epoch:46 , learn rate = 9.446824413773763e-06
Epoch:47 , learn rate = 8.974483193085074e-06
Epoch:48 , learn rate = 8.52575903343082e-06
Epoch:49 , learn rate = 8.09947108175928e-06
Epoch:50 , learn rate = 7.694497527671315e-06
Epoch:51 , learn rate = 7.309772651287749e-06
Epoch:52 , learn rate = 6.944284018723361e-06
Epoch:53 , learn rate = 6.597069817787193e-06
Epoch:54 , learn rate = 6.267216326897833e-06
Epoch:55 , learn rate = 5.953855510552941e-06
Epoch:56 , learn rate = 5.6561627350252934e-06
Epoch:57 , learn rate = 5.373354598274029e-06
Epoch:58 , learn rate = 5.104686868360327e-06
Epoch:59 , learn rate = 4.84945252494231e-06
Epoch:60 , learn rate = 4.6069798986951945e-06
Epoch:61 , learn rate = 4.3766309037604346e-06
Epoch:62 , learn rate = 4.157799358572413e-06
Epoch:63 , learn rate = 3.949909390643792e-06
Epoch:64 , learn rate = 3.7524139211116024e-06
Epoch:65 , learn rate = 3.564793225056022e-06
Epoch:66 , learn rate = 3.3865535638032207e-06
Epoch:67 , learn rate = 3.2172258856130596e-06
Epoch:68 , learn rate = 3.0563645913324067e-06
Epoch:69 , learn rate = 2.903546361765786e-06
Epoch:70 , learn rate = 2.7583690436774965e-06
Epoch:71 , learn rate = 2.6204505914936218e-06
Epoch:72 , learn rate = 2.4894280619189404e-06
Epoch:73 , learn rate = 2.3649566588229932e-06
Epoch:74 , learn rate = 2.2467088258818436e-06
Epoch:75 , learn rate = 2.1343733845877514e-06
Epoch:76 , learn rate = 2.027654715358364e-06
Epoch:77 , learn rate = 1.9262719795904457e-06
Epoch:78 , learn rate = 1.8299583806109232e-06
Epoch:79 , learn rate = 1.738460461580377e-06
Epoch:80 , learn rate = 1.651537438501358e-06
Epoch:81 , learn rate = 1.5689605665762901e-06
Epoch:82 , learn rate = 1.4905125382474756e-06
Epoch:83 , learn rate = 1.4159869113351018e-06
Epoch:84 , learn rate = 1.3451875657683467e-06
Epoch:85 , learn rate = 1.2779281874799292e-06
Epoch:86 , learn rate = 1.2140317781059327e-06
Epoch:87 , learn rate = 1.153330189200636e-06
Epoch:88 , learn rate = 1.0956636797406041e-06
Epoch:89 , learn rate = 1.0408804957535738e-06
Epoch:90 , learn rate = 9.88836470965895e-07
Epoch:91 , learn rate = 9.393946474176003e-07
Epoch:92 , learn rate = 8.924249150467202e-07
Epoch:93 , learn rate = 8.478036692943841e-07
Epoch:94 , learn rate = 8.054134858296649e-07
Epoch:95 , learn rate = 7.651428115381816e-07
Epoch:96 , learn rate = 7.268856709612725e-07
Epoch:97 , learn rate = 6.905413874132088e-07
Epoch:98 , learn rate = 6.560143180425484e-07
Epoch:99 , learn rate = 6.232136021404209e-07
The learning-rate curve is shown below:
2.3 torch.optim.lr_scheduler.StepLR
Every step_size epochs, the learning rate is decayed by multiplying it by gamma.
1. Class signature
torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1, verbose=False)
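2. Update rule
$$lr_{new} = lr_{initial} \times \gamma^{\lfloor epoch / step\_size \rfloor}$$
where $lr_{new}$ is the updated learning rate, $lr_{initial}$ is the initial learning rate, $epoch$ is the current epoch, $step\_size$ is the step size, and $\gamma$ is the decay factor gamma.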
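3. Parameters
- optimizer: Optimizer object. The optimizer whose learning rate is adjusted;
- step_size: int. Interval, in epochs, between learning-rate decays;
- gamma: float, default 0.1. Multiplicative factor of the learning-rate decay;
- last_epoch: int, default -1. Index of the last epoch. When resuming training that was interrupted after many epochs, set this to the epoch of the loaded checkpoint. The default -1 means training starts from scratch, i.e. from epoch 0;
- verbose: bool, default False. If True, a message is printed to the console at every update;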
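The update rule can be checked with a one-line closed form in plain Python. step_lr is a hypothetical helper, not part of PyTorch; the demo uses the same settings as the example that follows (initial lr 1e-4, step_size=10, gamma=0.3).

```python
def step_lr(initial_lr, step_size, gamma, epoch):
    """Closed form of StepLR: one decay by gamma every step_size epochs."""
    # epoch // step_size counts how many full steps have elapsed
    return initial_lr * gamma ** (epoch // step_size)


if __name__ == "__main__":
    for epoch in (0, 9, 10, 19, 20, 99):
        print(epoch, step_lr(1e-4, 10, 0.3, epoch))
```

The integer division is what produces the staircase shape: epochs 0 to 9 share one value, epochs 10 to 19 the next, and so on.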
4. Example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))
    # Multiply the lr by 0.3 every 10 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.3)
    lr_list = []
    epochs = 100
    for epoch in range(epochs):
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        print('Epoch:{} , learn rate = {}'.format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
        optimizer.zero_grad()
        # A real loop would compute the loss and call loss.backward() here
        optimizer.step()
        scheduler.step()
    plt.title("learn rate demo")
    plt.xlabel("epoch")
    plt.ylabel("lr")
    plt.plot(range(epochs), lr_list, color='r')
    plt.show()
Output
Epoch:0 , learn rate = 0.0001
Epoch:1 , learn rate = 0.0001
Epoch:2 , learn rate = 0.0001
Epoch:3 , learn rate = 0.0001
Epoch:4 , learn rate = 0.0001
Epoch:5 , learn rate = 0.0001
Epoch:6 , learn rate = 0.0001
Epoch:7 , learn rate = 0.0001
Epoch:8 , learn rate = 0.0001
Epoch:9 , learn rate = 0.0001
Epoch:10 , learn rate = 3e-05
Epoch:11 , learn rate = 3e-05
Epoch:12 , learn rate = 3e-05
Epoch:13 , learn rate = 3e-05
Epoch:14 , learn rate = 3e-05
Epoch:15 , learn rate = 3e-05
Epoch:16 , learn rate = 3e-05
Epoch:17 , learn rate = 3e-05
Epoch:18 , learn rate = 3e-05
Epoch:19 , learn rate = 3e-05
Epoch:20 , learn rate = 9e-06
Epoch:21 , learn rate = 9e-06
Epoch:22 , learn rate = 9e-06
Epoch:23 , learn rate = 9e-06
Epoch:24 , learn rate = 9e-06
Epoch:25 , learn rate = 9e-06
Epoch:26 , learn rate = 9e-06
Epoch:27 , learn rate = 9e-06
Epoch:28 , learn rate = 9e-06
Epoch:29 , learn rate = 9e-06
Epoch:30 , learn rate = 2.7e-06
Epoch:31 , learn rate = 2.7e-06
Epoch:32 , learn rate = 2.7e-06
Epoch:33 , learn rate = 2.7e-06
Epoch:34 , learn rate = 2.7e-06
Epoch:35 , learn rate = 2.7e-06
Epoch:36 , learn rate = 2.7e-06
Epoch:37 , learn rate = 2.7e-06
Epoch:38 , learn rate = 2.7e-06
Epoch:39 , learn rate = 2.7e-06
Epoch:40 , learn rate = 8.1e-07
Epoch:41 , learn rate = 8.1e-07
Epoch:42 , learn rate = 8.1e-07
Epoch:43 , learn rate = 8.1e-07
Epoch:44 , learn rate = 8.1e-07
Epoch:45 , learn rate = 8.1e-07
Epoch:46 , learn rate = 8.1e-07
Epoch:47 , learn rate = 8.1e-07
Epoch:48 , learn rate = 8.1e-07
Epoch:49 , learn rate = 8.1e-07
Epoch:50 , learn rate = 2.43e-07
Epoch:51 , learn rate = 2.43e-07
Epoch:52 , learn rate = 2.43e-07
Epoch:53 , learn rate = 2.43e-07
Epoch:54 , learn rate = 2.43e-07
Epoch:55 , learn rate = 2.43e-07
Epoch:56 , learn rate = 2.43e-07
Epoch:57 , learn rate = 2.43e-07
Epoch:58 , learn rate = 2.43e-07
Epoch:59 , learn rate = 2.43e-07
Epoch:60 , learn rate = 7.29e-08
Epoch:61 , learn rate = 7.29e-08
Epoch:62 , learn rate = 7.29e-08
Epoch:63 , learn rate = 7.29e-08
Epoch:64 , learn rate = 7.29e-08
Epoch:65 , learn rate = 7.29e-08
Epoch:66 , learn rate = 7.29e-08
Epoch:67 , learn rate = 7.29e-08
Epoch:68 , learn rate = 7.29e-08
Epoch:69 , learn rate = 7.29e-08
Epoch:70 , learn rate = 2.187e-08
Epoch:71 , learn rate = 2.187e-08
Epoch:72 , learn rate = 2.187e-08
Epoch:73 , learn rate = 2.187e-08
Epoch:74 , learn rate = 2.187e-08
Epoch:75 , learn rate = 2.187e-08
Epoch:76 , learn rate = 2.187e-08
Epoch:77 , learn rate = 2.187e-08
Epoch:78 , learn rate = 2.187e-08
Epoch:79 , learn rate = 2.187e-08
Epoch:80 , learn rate = 6.561e-09
Epoch:81 , learn rate = 6.561e-09
Epoch:82 , learn rate = 6.561e-09
Epoch:83 , learn rate = 6.561e-09
Epoch:84 , learn rate = 6.561e-09
Epoch:85 , learn rate = 6.561e-09
Epoch:86 , learn rate = 6.561e-09
Epoch:87 , learn rate = 6.561e-09
Epoch:88 , learn rate = 6.561e-09
Epoch:89 , learn rate = 6.561e-09
Epoch:90 , learn rate = 1.9683e-09
Epoch:91 , learn rate = 1.9683e-09
Epoch:92 , learn rate = 1.9683e-09
Epoch:93 , learn rate = 1.9683e-09
Epoch:94 , learn rate = 1.9683e-09
Epoch:95 , learn rate = 1.9683e-09
Epoch:96 , learn rate = 1.9683e-09
Epoch:97 , learn rate = 1.9683e-09
Epoch:98 , learn rate = 1.9683e-09
Epoch:99 , learn rate = 1.9683e-09
The learning-rate curve is shown below:
2.4 torch.optim.lr_scheduler.MultiStepLR
Each time the epoch reaches one of the values in milestones, the learning rate is decayed by multiplying it by gamma.
1. Class signature
torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False)
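2. Update rule
$$lr_{new} = lr_{initial} \times \gamma^{\mathrm{bisect\_right}(milestones,\ epoch)}$$
where $lr_{new}$ is the updated learning rate, $lr_{initial}$ is the initial learning rate, and $epoch$ is the current epoch. bisect_right is the function of that name in Python's bisect module: it returns the position at which epoch would be inserted into the sorted list milestones (to the right of any equal elements), i.e. the number of milestones less than or equal to epoch. For example: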
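import bisect

li = [1, 23, 45, 12, 23, 42, 54, 123, 14, 52, 3]
li.sort()  # bisect functions require a sorted list
print(li)
# bisect.bisect is an alias of bisect.bisect_right: 3 is inserted after the existing 3
print(bisect.bisect_right(li, 3))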
Output
[1, 3, 12, 14, 23, 23, 42, 45, 52, 54, 123]
2
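3. Parameters
- optimizer: Optimizer object. The optimizer whose learning rate is adjusted;
- milestones: list. Epoch indices in increasing order, e.g. [10, 20, 30];
- gamma: float, default 0.1. Multiplicative factor of the learning-rate decay;
- last_epoch: int, default -1. Index of the last epoch. When resuming training that was interrupted after many epochs, set this to the epoch of the loaded checkpoint. The default -1 means training starts from scratch, i.e. from epoch 0;
- verbose: bool, default False. If True, a message is printed to the console at every update;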
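The bisect_right formula translates directly into plain Python. multistep_lr is a hypothetical helper, not a PyTorch function; the demo reuses the milestones and gamma of the example that follows.

```python
from bisect import bisect_right


def multistep_lr(initial_lr, milestones, gamma, epoch):
    """Closed form of MultiStepLR: one gamma decay per milestone already reached."""
    # bisect_right(milestones, epoch) = number of milestones <= epoch
    return initial_lr * gamma ** bisect_right(milestones, epoch)


if __name__ == "__main__":
    milestones = [10, 30, 50, 70]
    for epoch in (9, 10, 29, 30, 69, 70, 99):
        print(epoch, multistep_lr(1e-4, milestones, 0.3, epoch))
```

Using bisect_right rather than bisect_left is what makes the decay take effect exactly at a milestone epoch: at epoch 10, the milestone 10 is already counted.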
4. Example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))
    # Multiply the lr by 0.3 when reaching epochs 10, 30, 50 and 70
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 30, 50, 70], gamma=0.3)
    lr_list = []
    epochs = 100
    for epoch in range(epochs):
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        print('Epoch:{} , learn rate = {}'.format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
        optimizer.zero_grad()
        # A real loop would compute the loss and call loss.backward() here
        optimizer.step()
        scheduler.step()
    plt.title("learn rate demo")
    plt.xlabel("epoch")
    plt.ylabel("lr")
    plt.plot(range(epochs), lr_list, color='r')
    plt.show()
Output
Epoch:0 , learn rate = 0.0001
Epoch:1 , learn rate = 0.0001
Epoch:2 , learn rate = 0.0001
Epoch:3 , learn rate = 0.0001
Epoch:4 , learn rate = 0.0001
Epoch:5 , learn rate = 0.0001
Epoch:6 , learn rate = 0.0001
Epoch:7 , learn rate = 0.0001
Epoch:8 , learn rate = 0.0001
Epoch:9 , learn rate = 0.0001
Epoch:10 , learn rate = 3e-05
Epoch:11 , learn rate = 3e-05
Epoch:12 , learn rate = 3e-05
Epoch:13 , learn rate = 3e-05
Epoch:14 , learn rate = 3e-05
Epoch:15 , learn rate = 3e-05
Epoch:16 , learn rate = 3e-05
Epoch:17 , learn rate = 3e-05
Epoch:18 , learn rate = 3e-05
Epoch:19 , learn rate = 3e-05
Epoch:20 , learn rate = 3e-05
Epoch:21 , learn rate = 3e-05
Epoch:22 , learn rate = 3e-05
Epoch:23 , learn rate = 3e-05
Epoch:24 , learn rate = 3e-05
Epoch:25 , learn rate = 3e-05
Epoch:26 , learn rate = 3e-05
Epoch:27 , learn rate = 3e-05
Epoch:28 , learn rate = 3e-05
Epoch:29 , learn rate = 3e-05
Epoch:30 , learn rate = 9e-06
Epoch:31 , learn rate = 9e-06
Epoch:32 , learn rate = 9e-06
Epoch:33 , learn rate = 9e-06
Epoch:34 , learn rate = 9e-06
Epoch:35 , learn rate = 9e-06
Epoch:36 , learn rate = 9e-06
Epoch:37 , learn rate = 9e-06
Epoch:38 , learn rate = 9e-06
Epoch:39 , learn rate = 9e-06
Epoch:40 , learn rate = 9e-06
Epoch:41 , learn rate = 9e-06
Epoch:42 , learn rate = 9e-06
Epoch:43 , learn rate = 9e-06
Epoch:44 , learn rate = 9e-06
Epoch:45 , learn rate = 9e-06
Epoch:46 , learn rate = 9e-06
Epoch:47 , learn rate = 9e-06
Epoch:48 , learn rate = 9e-06
Epoch:49 , learn rate = 9e-06
Epoch:50 , learn rate = 2.7e-06
Epoch:51 , learn rate = 2.7e-06
Epoch:52 , learn rate = 2.7e-06
Epoch:53 , learn rate = 2.7e-06
Epoch:54 , learn rate = 2.7e-06
Epoch:55 , learn rate = 2.7e-06
Epoch:56 , learn rate = 2.7e-06
Epoch:57 , learn rate = 2.7e-06
Epoch:58 , learn rate = 2.7e-06
Epoch:59 , learn rate = 2.7e-06
Epoch:60 , learn rate = 2.7e-06
Epoch:61 , learn rate = 2.7e-06
Epoch:62 , learn rate = 2.7e-06
Epoch:63 , learn rate = 2.7e-06
Epoch:64 , learn rate = 2.7e-06
Epoch:65 , learn rate = 2.7e-06
Epoch:66 , learn rate = 2.7e-06
Epoch:67 , learn rate = 2.7e-06
Epoch:68 , learn rate = 2.7e-06
Epoch:69 , learn rate = 2.7e-06
Epoch:70 , learn rate = 8.1e-07
Epoch:71 , learn rate = 8.1e-07
Epoch:72 , learn rate = 8.1e-07
Epoch:73 , learn rate = 8.1e-07
Epoch:74 , learn rate = 8.1e-07
Epoch:75 , learn rate = 8.1e-07
Epoch:76 , learn rate = 8.1e-07
Epoch:77 , learn rate = 8.1e-07
Epoch:78 , learn rate = 8.1e-07
Epoch:79 , learn rate = 8.1e-07
Epoch:80 , learn rate = 8.1e-07
Epoch:81 , learn rate = 8.1e-07
Epoch:82 , learn rate = 8.1e-07
Epoch:83 , learn rate = 8.1e-07
Epoch:84 , learn rate = 8.1e-07
Epoch:85 , learn rate = 8.1e-07
Epoch:86 , learn rate = 8.1e-07
Epoch:87 , learn rate = 8.1e-07
Epoch:88 , learn rate = 8.1e-07
Epoch:89 , learn rate = 8.1e-07
Epoch:90 , learn rate = 8.1e-07
Epoch:91 , learn rate = 8.1e-07
Epoch:92 , learn rate = 8.1e-07
Epoch:93 , learn rate = 8.1e-07
Epoch:94 , learn rate = 8.1e-07
Epoch:95 , learn rate = 8.1e-07
Epoch:96 , learn rate = 8.1e-07
Epoch:97 , learn rate = 8.1e-07
Epoch:98 , learn rate = 8.1e-07
Epoch:99 , learn rate = 8.1e-07
The learning-rate curve is shown below:
2.5 torch.optim.lr_scheduler.ExponentialLR
Decays the learning rate by multiplying it by gamma once per epoch. When last_epoch=-1, the scheduler starts from the optimizer's initial learning rate lr.
1. Class signature
torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma, last_epoch=-1, verbose=False)
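2. Parameters
- optimizer: Optimizer object. The optimizer whose learning rate is adjusted;
- gamma: float. Multiplicative factor of the learning-rate decay;
- last_epoch: int, default -1. Index of the last epoch;
- verbose: bool, default False. If True, a message is printed to the console at every update;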
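ExponentialLR behaves like StepLR with step_size=1, so the learning rate at a given epoch has the closed form $lr_{initial} \times \gamma^{epoch}$. The plain-Python sketch below compares that closed form with the step-by-step multiplication; both helper names are hypothetical, and the settings (initial lr 1e-4, gamma=0.9) match the example that follows.

```python
def exponential_lr(initial_lr, gamma, epoch):
    """Closed form of ExponentialLR: one gamma decay per epoch."""
    return initial_lr * gamma ** epoch


def exponential_lr_iterative(initial_lr, gamma, epochs):
    """Stateful version: record the current lr, then multiply it by gamma."""
    lrs, lr = [], initial_lr
    for _ in range(epochs):
        lrs.append(lr)
        lr *= gamma
    return lrs


if __name__ == "__main__":
    it = exponential_lr_iterative(1e-4, 0.9, 100)
    # Largest absolute gap between the two forms (floating-point rounding only)
    print(max(abs(a - exponential_lr(1e-4, 0.9, e)) for e, a in enumerate(it)))
```

The small last-digit differences in the logged output (for example 7.290000000000001e-05 at epoch 3) come from this repeated floating-point multiplication, not from a different formula.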
3. Example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt

class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))
    # Multiply the lr by 0.9 after every epoch
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    lr_list = []
    epochs = 100
    for epoch in range(epochs):
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        print('Epoch:{} , learn rate = {}'.format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
        optimizer.zero_grad()
        # A real loop would compute the loss and call loss.backward() here
        optimizer.step()
        scheduler.step()
    plt.title("learn rate demo")
    plt.xlabel("epoch")
    plt.ylabel("lr")
    plt.plot(range(epochs), lr_list, color='r')
    plt.show()
Output
Epoch:0 , learn rate = 0.0001
Epoch:1 , learn rate = 9e-05
Epoch:2 , learn rate = 8.1e-05
Epoch:3 , learn rate = 7.290000000000001e-05
Epoch:4 , learn rate = 6.561000000000002e-05
Epoch:5 , learn rate = 5.904900000000002e-05
Epoch:6 , learn rate = 5.314410000000002e-05
Epoch:7 , learn rate = 4.782969000000002e-05
Epoch:8 , learn rate = 4.304672100000002e-05
Epoch:9 , learn rate = 3.874204890000002e-05
Epoch:10 , learn rate = 3.4867844010000016e-05
Epoch:11 , learn rate = 3.138105960900002e-05
Epoch:12 , learn rate = 2.8242953648100018e-05
Epoch:13 , learn rate = 2.5418658283290016e-05
Epoch:14 , learn rate = 2.2876792454961016e-05
Epoch:15 , learn rate = 2.0589113209464913e-05
Epoch:16 , learn rate = 1.8530201888518422e-05
Epoch:17 , learn rate = 1.667718169966658e-05
Epoch:18 , learn rate = 1.5009463529699922e-05
Epoch:19 , learn rate = 1.350851717672993e-05
Epoch:20 , learn rate = 1.2157665459056937e-05
Epoch:21 , learn rate = 1.0941898913151244e-05
Epoch:22 , learn rate = 9.84770902183612e-06
Epoch:23 , learn rate = 8.862938119652508e-06
Epoch:24 , learn rate = 7.976644307687257e-06
Epoch:25 , learn rate = 7.1789798769185315e-06
Epoch:26 , learn rate = 6.461081889226678e-06
Epoch:27 , learn rate = 5.81497370030401e-06
Epoch:28 , learn rate = 5.23347633027361e-06
Epoch:29 , learn rate = 4.710128697246249e-06
Epoch:30 , learn rate = 4.239115827521624e-06
Epoch:31 , learn rate = 3.815204244769462e-06
Epoch:32 , learn rate = 3.4336838202925156e-06
Epoch:33 , learn rate = 3.090315438263264e-06
Epoch:34 , learn rate = 2.7812838944369375e-06
Epoch:35 , learn rate = 2.503155504993244e-06
Epoch:36 , learn rate = 2.2528399544939195e-06
Epoch:37 , learn rate = 2.0275559590445276e-06
Epoch:38 , learn rate = 1.824800363140075e-06
Epoch:39 , learn rate = 1.6423203268260674e-06
Epoch:40 , learn rate = 1.4780882941434607e-06
Epoch:41 , learn rate = 1.3302794647291146e-06
Epoch:42 , learn rate = 1.1972515182562032e-06
Epoch:43 , learn rate = 1.0775263664305828e-06
Epoch:44 , learn rate = 9.697737297875246e-07
Epoch:45 , learn rate = 8.727963568087721e-07
Epoch:46 , learn rate = 7.855167211278949e-07
Epoch:47 , learn rate = 7.069650490151055e-07
Epoch:48 , learn rate = 6.362685441135949e-07
Epoch:49 , learn rate = 5.726416897022355e-07
Epoch:50 , learn rate = 5.15377520732012e-07
Epoch:51 , learn rate = 4.638397686588108e-07
Epoch:52 , learn rate = 4.1745579179292974e-07
Epoch:53 , learn rate = 3.7571021261363677e-07
Epoch:54 , learn rate = 3.3813919135227313e-07
Epoch:55 , learn rate = 3.043252722170458e-07
Epoch:56 , learn rate = 2.7389274499534124e-07
Epoch:57 , learn rate = 2.465034704958071e-07
Epoch:58 , learn rate = 2.218531234462264e-07
Epoch:59 , learn rate = 1.9966781110160376e-07
Epoch:60 , learn rate = 1.7970102999144338e-07
Epoch:61 , learn rate = 1.6173092699229905e-07
Epoch:62 , learn rate = 1.4555783429306916e-07
Epoch:63 , learn rate = 1.3100205086376224e-07
Epoch:64 , learn rate = 1.1790184577738602e-07
Epoch:65 , learn rate = 1.0611166119964742e-07
Epoch:66 , learn rate = 9.550049507968268e-08
Epoch:67 , learn rate = 8.595044557171442e-08
Epoch:68 , learn rate = 7.735540101454298e-08
Epoch:69 , learn rate = 6.961986091308869e-08
Epoch:70 , learn rate = 6.265787482177982e-08
Epoch:71 , learn rate = 5.6392087339601844e-08
Epoch:72 , learn rate = 5.075287860564166e-08
Epoch:73 , learn rate = 4.567759074507749e-08
Epoch:74 , learn rate = 4.1109831670569744e-08
Epoch:75 , learn rate = 3.699884850351277e-08
Epoch:76 , learn rate = 3.3298963653161496e-08
Epoch:77 , learn rate = 2.996906728784535e-08
Epoch:78 , learn rate = 2.6972160559060813e-08
Epoch:79 , learn rate = 2.4274944503154732e-08
Epoch:80 , learn rate = 2.184745005283926e-08
Epoch:81 , learn rate = 1.9662705047555335e-08
Epoch:82 , learn rate = 1.7696434542799804e-08
Epoch:83 , learn rate = 1.5926791088519824e-08
Epoch:84 , learn rate = 1.4334111979667841e-08
Epoch:85 , learn rate = 1.2900700781701057e-08
Epoch:86 , learn rate = 1.1610630703530951e-08
Epoch:87 , learn rate = 1.0449567633177856e-08
Epoch:88 , learn rate = 9.40461086986007e-09
Epoch:89 , learn rate = 8.464149782874063e-09
Epoch:90 , learn rate = 7.617734804586658e-09
Epoch:91 , learn rate = 6.855961324127992e-09
Epoch:92 , learn rate = 6.170365191715193e-09
Epoch:93 , learn rate = 5.553328672543673e-09
Epoch:94 , learn rate = 4.997995805289306e-09
Epoch:95 , learn rate = 4.498196224760376e-09
Epoch:96 , learn rate = 4.048376602284338e-09
Epoch:97 , learn rate = 3.6435389420559043e-09
Epoch:98 , learn rate = 3.279185047850314e-09
Epoch:99 , learn rate = 2.951266543065283e-09
The learning-rate curve is shown below:
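The printed values above shrink by the same factor every epoch. As a quick cross-check, a closed-form geometric decay lr_t = base_lr * gamma^t reproduces them. This assumes base_lr = 1e-4 and gamma = 0.9. Both values are inferred from the printed numbers, because the scheduler settings are not part of this excerpt:

```python
# Assumed values, inferred from the printed output above
BASE_LR = 1e-4
GAMMA = 0.9


def exp_decay_lr(epoch, base_lr=BASE_LR, gamma=GAMMA):
    """Learning rate after `epoch` multiplicative decays: base_lr * gamma**epoch."""
    return base_lr * gamma ** epoch


for epoch in (30, 66, 99):
    print("Epoch:{} , learn rate = {}".format(epoch, exp_decay_lr(epoch)))
```

Any differences in the last digits come from floating-point rounding. For example, updating the lr by repeated multiplication rounds differently from evaluating the power directly.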
2.6 torch.optim.lr_scheduler.CosineAnnealingLR
Adjusts the learning rate using a cosine annealing schedule.
1. Class signature
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False)
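2. Parameters
- optimizer: Optimizer. The optimizer whose learning rate is to be adjusted;
- T_max: int, the maximum number of iterations. This is the number of scheduler steps over which the lr anneals from its initial value down to eta_min. If stepping continues past T_max, the lr climbs back up, so the curve repeats every 2 * T_max steps;
- eta_min: float, the minimum learning rate. Default 0;
- last_epoch: int, the index of the last epoch. Default -1;
- verbose: bool, default False. If True, a message is printed to the console on every update. This parameter is deprecated in recent PyTorch releases; use get_last_lr() instead;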
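The PyTorch documentation gives this schedule in closed form:

eta_t = eta_min + (eta_max - eta_min) / 2 * (1 + cos(pi * T_cur / T_max))

Here eta_max is the optimizer's initial lr and T_cur counts scheduler steps. The sketch below evaluates that formula in pure Python with the example's settings, base_lr = 1e-4 and T_max = 10. `cosine_annealing_lr` is a hypothetical helper written for illustration, not a PyTorch function:

```python
import math


def cosine_annealing_lr(t, base_lr=1e-4, t_max=10, eta_min=0.0):
    """Closed-form cosine-annealed lr after t scheduler steps (no restarts)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))


# First two half-periods with the example's settings
for t in (0, 1, 5, 10, 15, 20):
    print(t, cosine_annealing_lr(t))
```

Because cos is periodic, evaluating past T_max makes the lr rise back to base_lr at T_cur = 2 * T_max. That is the 20-epoch oscillation visible in the output below. PyTorch itself updates the lr recursively from the previous step, so its printed values can differ from the closed form in the last few digits.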
3. Example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)
    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))
    # T_max=10: the lr falls from 1e-4 to eta_min (0) in 10 steps, then rises again
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    lr_list = []
    epochs = 100
    for epoch in range(epochs):
        # Record the lr used in this epoch, before the scheduler changes it
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        print('Epoch:{} , learn rate = {}'.format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
        # A real training loop would compute the loss and call loss.backward() here
        optimizer.zero_grad()
        optimizer.step()
        # Update the lr once per epoch, after optimizer.step()
        scheduler.step()
    plt.title("learn rate demo")
    plt.xlabel("epoch")
    plt.ylabel("lr")
    plt.plot(range(epochs), lr_list, color='r')
    plt.show()
Output
Epoch:0 , learn rate = 0.0001
Epoch:1 , learn rate = 9.755282581475769e-05
Epoch:2 , learn rate = 9.045084971874737e-05
Epoch:3 , learn rate = 7.938926261462366e-05
Epoch:4 , learn rate = 6.545084971874737e-05
Epoch:5 , learn rate = 4.9999999999999996e-05
Epoch:6 , learn rate = 3.454915028125263e-05
Epoch:7 , learn rate = 2.0610737385376345e-05
Epoch:8 , learn rate = 9.549150281252631e-06
Epoch:9 , learn rate = 2.447174185242323e-06
Epoch:10 , learn rate = 0.0
Epoch:11 , learn rate = 2.4471741852423237e-06
Epoch:12 , learn rate = 9.549150281252672e-06
Epoch:13 , learn rate = 2.0610737385376434e-05
Epoch:14 , learn rate = 3.4549150281252785e-05
Epoch:15 , learn rate = 5.0000000000000226e-05
Epoch:16 , learn rate = 6.545084971874767e-05
Epoch:17 , learn rate = 7.938926261462401e-05
Epoch:18 , learn rate = 9.045084971874779e-05
Epoch:19 , learn rate = 9.755282581475812e-05
Epoch:20 , learn rate = 0.00010000000000000044
Epoch:21 , learn rate = 9.75528258147581e-05
Epoch:22 , learn rate = 9.04508497187478e-05
Epoch:23 , learn rate = 7.938926261462401e-05
Epoch:24 , learn rate = 6.545084971874767e-05
Epoch:25 , learn rate = 5.0000000000000226e-05
Epoch:26 , learn rate = 3.454915028125279e-05
Epoch:27 , learn rate = 2.0610737385376444e-05
Epoch:28 , learn rate = 9.54915028125268e-06
Epoch:29 , learn rate = 2.44717418524234e-06
Epoch:30 , learn rate = 0.0
Epoch:31 , learn rate = 2.4471741852423237e-06
Epoch:32 , learn rate = 9.54915028125264e-06
Epoch:33 , learn rate = 2.0610737385376376e-05
Epoch:34 , learn rate = 3.454915028125269e-05
Epoch:35 , learn rate = 5.0000000000000104e-05
Epoch:36 , learn rate = 6.545084971874752e-05
Epoch:37 , learn rate = 7.938926261462385e-05
Epoch:38 , learn rate = 9.045084971874758e-05
Epoch:39 , learn rate = 9.755282581475792e-05
Epoch:40 , learn rate = 0.00010000000000000026
Epoch:41 , learn rate = 9.755282581475797e-05
Epoch:42 , learn rate = 9.045084971874764e-05
Epoch:43 , learn rate = 7.938926261462389e-05
Epoch:44 , learn rate = 6.545084971874765e-05
Epoch:45 , learn rate = 5.000000000000015e-05
Epoch:46 , learn rate = 3.4549150281252744e-05
Epoch:47 , learn rate = 2.0610737385376345e-05
Epoch:48 , learn rate = 9.549150281252669e-06
Epoch:49 , learn rate = 2.447174185242363e-06
Epoch:50 , learn rate = 0.0
Epoch:51 , learn rate = 2.4471741852423237e-06
Epoch:52 , learn rate = 9.549150281252763e-06
Epoch:53 , learn rate = 2.0610737385376718e-05
Epoch:54 , learn rate = 3.4549150281253144e-05
Epoch:55 , learn rate = 5.000000000000085e-05
Epoch:56 , learn rate = 6.545084971874837e-05
Epoch:57 , learn rate = 7.938926261462481e-05
Epoch:58 , learn rate = 9.045084971874878e-05
Epoch:59 , learn rate = 9.755282581475916e-05
Epoch:60 , learn rate = 0.00010000000000000156
Epoch:61 , learn rate = 9.755282581475919e-05
Epoch:62 , learn rate = 9.045084971874882e-05
Epoch:63 , learn rate = 7.938926261462487e-05
Epoch:64 , learn rate = 6.545084971874843e-05
Epoch:65 , learn rate = 5.000000000000092e-05
Epoch:66 , learn rate = 3.4549150281253205e-05
Epoch:67 , learn rate = 2.0610737385376772e-05
Epoch:68 , learn rate = 9.549150281252806e-06
Epoch:69 , learn rate = 2.4471741852423453e-06
Epoch:70 , learn rate = 0.0
Epoch:71 , learn rate = 2.4471741852423237e-06
Epoch:72 , learn rate = 9.549150281252564e-06
Epoch:73 , learn rate = 2.0610737385376146e-05
Epoch:74 , learn rate = 3.454915028125243e-05
Epoch:75 , learn rate = 4.999999999999965e-05
Epoch:76 , learn rate = 6.545084971874704e-05
Epoch:77 , learn rate = 7.938926261462335e-05
Epoch:78 , learn rate = 9.045084971874696e-05
Epoch:79 , learn rate = 9.755282581475727e-05
Epoch:80 , learn rate = 9.999999999999957e-05
Epoch:81 , learn rate = 9.75528258147573e-05
Epoch:82 , learn rate = 9.045084971874712e-05
Epoch:83 , learn rate = 7.93892626146233e-05
Epoch:84 , learn rate = 6.545084971874714e-05
Epoch:85 , learn rate = 4.999999999999975e-05
Epoch:86 , learn rate = 3.454915028125254e-05
Epoch:87 , learn rate = 2.0610737385376376e-05
Epoch:88 , learn rate = 9.549150281252728e-06
Epoch:89 , learn rate = 2.4471741852423026e-06
Epoch:90 , learn rate = 0.0
Epoch:91 , learn rate = 2.4471741852423237e-06
Epoch:92 , learn rate = 9.549150281252769e-06
Epoch:93 , learn rate = 2.0610737385376603e-05
Epoch:94 , learn rate = 3.4549150281253375e-05
Epoch:95 , learn rate = 5.0000000000000944e-05
Epoch:96 , learn rate = 6.545084971874851e-05
Epoch:97 , learn rate = 7.938926261462497e-05
Epoch:98 , learn rate = 9.045084971874887e-05
Epoch:99 , learn rate = 9.755282581475945e-05
The learning-rate curve is shown below:
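In the example above, T_max=10 is much smaller than the 100 training epochs, so the lr swings back up every 20 epochs. To get a single smooth decay over the whole run, a common choice is to set T_max to the total number of epochs. The sketch below shows this. The one-layer model and eta_min=1e-6 are illustrative assumptions only:

```python
import torch

# Hypothetical stand-in model; any module with parameters works
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
epochs = 100
# T_max equal to the total number of epochs gives one half-cosine decay, no rise
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

lr_list = []
for epoch in range(epochs):
    lr_list.append(optimizer.param_groups[0]["lr"])
    optimizer.zero_grad()
    optimizer.step()  # a real loop would call loss.backward() before this
    scheduler.step()

print(lr_list[0], lr_list[50], lr_list[-1])
```

With this setting the lr decreases monotonically from 1e-4 and only approaches eta_min at the end of training.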
2.7 torch.optim.lr_scheduler.ReduceLROnPlateau
Monitors a metric (for example, the validation loss or accuracy) and reduces the learning rate when that metric stops improving.
1. Class signature
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08, verbose=False)
2. Parameters
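- optimizer: Optimizer. The optimizer whose learning rate is to be adjusted;
- mode: str, either "min" or "max". In "min" mode the lr is reduced when the monitored metric stops decreasing; in "max" mode it is reduced when the metric stops increasing. Default "min";
- factor: float, the factor by which the lr is reduced, new_lr = lr * factor. Default 0.1;
- patience: int, the number of epochs with no improvement after which the lr is reduced. Default 10. For example, with patience=2 the first 2 epochs without improvement are tolerated, and the lr is reduced only if the metric still has not improved in the 3rd epoch;
- threshold: float, the threshold for measuring a new optimum, so that only significant changes count as an improvement. Default 1e-4;
- threshold_mode: str, either "rel" or "abs". In "rel" mode, dynamic_threshold = best * (1 + threshold) in max mode and best * (1 - threshold) in min mode. In "abs" mode, dynamic_threshold = best + threshold in max mode and best - threshold in min mode. Default "rel";
- cooldown: int, the number of epochs to wait after an lr reduction before resuming normal operation. Default 0;
- min_lr: float or list, a scalar or a list of scalars giving a lower bound on the lr of all parameter groups or of each group respectively. Default 0;
- eps: float, the minimal decay applied to the lr. If the difference between the new and old lr is smaller than eps, the update is ignored. Default 1e-8;
- verbose: bool, default False. If True, a message is printed to the console on every update. This parameter is deprecated in recent PyTorch releases;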
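The interaction of patience, threshold, threshold_mode, cooldown, min_lr and eps is easiest to see in isolation. Below is a simplified pure-Python re-creation of the decision rule described by these parameters. It is a sketch written for this article, not PyTorch's source, and `is_better` and `simulate_plateau` are hypothetical helpers. Applied to the simulated metric used in the example that follows, it reproduces the printed learning rates:

```python
def is_better(current, best, mode="min", threshold=1e-4, threshold_mode="rel"):
    """Improvement test described by `mode`, `threshold` and `threshold_mode`."""
    if mode == "min" and threshold_mode == "rel":
        return current < best * (1 - threshold)
    if mode == "min" and threshold_mode == "abs":
        return current < best - threshold
    if mode == "max" and threshold_mode == "rel":
        return current > best * (1 + threshold)
    return current > best + threshold  # mode="max", threshold_mode="abs"


def simulate_plateau(metrics, lr, mode="min", factor=0.1, patience=10,
                     threshold=1e-4, threshold_mode="rel", cooldown=0,
                     min_lr=0.0, eps=1e-8):
    """Return the lr in effect after each step that receives one metric value."""
    best = float("inf") if mode == "min" else float("-inf")
    num_bad, cooldown_left, lrs = 0, 0, []
    for value in metrics:
        if is_better(value, best, mode, threshold, threshold_mode):
            best, num_bad = value, 0
        else:
            num_bad += 1
        if cooldown_left > 0:  # epochs inside the cooldown window are not counted
            cooldown_left -= 1
            num_bad = 0
        if num_bad > patience:  # strictly more than `patience` bad epochs
            new_lr = max(lr * factor, min_lr)
            if lr - new_lr > eps:  # skip negligible updates
                lr = new_lr
            cooldown_left, num_bad = cooldown, 0
        lrs.append(lr)
    return lrs


# Reproduce the example's simulated metric: it improves by 2 every 5 epochs
metrics, monitor_val = [], 1000
for epoch in range(100):
    if epoch % 5 == 0:
        monitor_val -= 2
    metrics.append(monitor_val)

lrs = simulate_plateau(metrics, lr=1e-4, factor=0.85, patience=2)
# lrs[e] is the lr printed at epoch e + 1 in the example's output
print(lrs[2], lrs[3], lrs[98])
```

With patience=2, the lr is cut on the third consecutive epoch without improvement. This is why the output changes at epochs 4, 9, 14 and so on.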
3. Example
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.hidden = nn.Linear(1, 20)
        self.predict = nn.Linear(20, 1)
    def forward(self, x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MyNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.99))
    # Multiply the lr by 0.85 once the metric has not improved for more than 2 epochs
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2, factor=0.85)
    lr_list = []
    # Simulated metric that only improves (by 2) every 5 epochs
    monitor_val = 1000
    epochs = 100
    for epoch in range(epochs):
        lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])
        print('Epoch:{} , learn rate = {}'.format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
        optimizer.zero_grad()
        optimizer.step()
        if epoch % 5 == 0:
            monitor_val -= 2
        # Unlike the other schedulers, step() takes the monitored metric as its argument
        scheduler.step(monitor_val)
    plt.title("learn rate demo")
    plt.xlabel("epoch")
    plt.ylabel("lr")
    plt.plot(range(epochs), lr_list, color='r')
    plt.show()
Output
Epoch:0 , learn rate = 0.0001
Epoch:1 , learn rate = 0.0001
Epoch:2 , learn rate = 0.0001
Epoch:3 , learn rate = 0.0001
Epoch:4 , learn rate = 8.5e-05
Epoch:5 , learn rate = 8.5e-05
Epoch:6 , learn rate = 8.5e-05
Epoch:7 , learn rate = 8.5e-05
Epoch:8 , learn rate = 8.5e-05
Epoch:9 , learn rate = 7.225000000000001e-05
Epoch:10 , learn rate = 7.225000000000001e-05
Epoch:11 , learn rate = 7.225000000000001e-05
Epoch:12 , learn rate = 7.225000000000001e-05
Epoch:13 , learn rate = 7.225000000000001e-05
Epoch:14 , learn rate = 6.141250000000001e-05
Epoch:15 , learn rate = 6.141250000000001e-05
Epoch:16 , learn rate = 6.141250000000001e-05
Epoch:17 , learn rate = 6.141250000000001e-05
Epoch:18 , learn rate = 6.141250000000001e-05
Epoch:19 , learn rate = 5.2200625000000005e-05
Epoch:20 , learn rate = 5.2200625000000005e-05
Epoch:21 , learn rate = 5.2200625000000005e-05
Epoch:22 , learn rate = 5.2200625000000005e-05
Epoch:23 , learn rate = 5.2200625000000005e-05
Epoch:24 , learn rate = 4.437053125e-05
Epoch:25 , learn rate = 4.437053125e-05
Epoch:26 , learn rate = 4.437053125e-05
Epoch:27 , learn rate = 4.437053125e-05
Epoch:28 , learn rate = 4.437053125e-05
Epoch:29 , learn rate = 3.77149515625e-05
Epoch:30 , learn rate = 3.77149515625e-05
Epoch:31 , learn rate = 3.77149515625e-05
Epoch:32 , learn rate = 3.77149515625e-05
Epoch:33 , learn rate = 3.77149515625e-05
Epoch:34 , learn rate = 3.2057708828124995e-05
Epoch:35 , learn rate = 3.2057708828124995e-05
Epoch:36 , learn rate = 3.2057708828124995e-05
Epoch:37 , learn rate = 3.2057708828124995e-05
Epoch:38 , learn rate = 3.2057708828124995e-05
Epoch:39 , learn rate = 2.7249052503906245e-05
Epoch:40 , learn rate = 2.7249052503906245e-05
Epoch:41 , learn rate = 2.7249052503906245e-05
Epoch:42 , learn rate = 2.7249052503906245e-05
Epoch:43 , learn rate = 2.7249052503906245e-05
Epoch:44 , learn rate = 2.3161694628320308e-05
Epoch:45 , learn rate = 2.3161694628320308e-05
Epoch:46 , learn rate = 2.3161694628320308e-05
Epoch:47 , learn rate = 2.3161694628320308e-05
Epoch:48 , learn rate = 2.3161694628320308e-05
Epoch:49 , learn rate = 1.9687440434072263e-05
Epoch:50 , learn rate = 1.9687440434072263e-05
Epoch:51 , learn rate = 1.9687440434072263e-05
Epoch:52 , learn rate = 1.9687440434072263e-05
Epoch:53 , learn rate = 1.9687440434072263e-05
Epoch:54 , learn rate = 1.673432436896142e-05
Epoch:55 , learn rate = 1.673432436896142e-05
Epoch:56 , learn rate = 1.673432436896142e-05
Epoch:57 , learn rate = 1.673432436896142e-05
Epoch:58 , learn rate = 1.673432436896142e-05
Epoch:59 , learn rate = 1.4224175713617208e-05
Epoch:60 , learn rate = 1.4224175713617208e-05
Epoch:61 , learn rate = 1.4224175713617208e-05
Epoch:62 , learn rate = 1.4224175713617208e-05
Epoch:63 , learn rate = 1.4224175713617208e-05
Epoch:64 , learn rate = 1.2090549356574626e-05
Epoch:65 , learn rate = 1.2090549356574626e-05
Epoch:66 , learn rate = 1.2090549356574626e-05
Epoch:67 , learn rate = 1.2090549356574626e-05
Epoch:68 , learn rate = 1.2090549356574626e-05
Epoch:69 , learn rate = 1.0276966953088432e-05
Epoch:70 , learn rate = 1.0276966953088432e-05
Epoch:71 , learn rate = 1.0276966953088432e-05
Epoch:72 , learn rate = 1.0276966953088432e-05
Epoch:73 , learn rate = 1.0276966953088432e-05
Epoch:74 , learn rate = 8.735421910125167e-06
Epoch:75 , learn rate = 8.735421910125167e-06
Epoch:76 , learn rate = 8.735421910125167e-06
Epoch:77 , learn rate = 8.735421910125167e-06
Epoch:78 , learn rate = 8.735421910125167e-06
Epoch:79 , learn rate = 7.425108623606392e-06
Epoch:80 , learn rate = 7.425108623606392e-06
Epoch:81 , learn rate = 7.425108623606392e-06
Epoch:82 , learn rate = 7.425108623606392e-06
Epoch:83 , learn rate = 7.425108623606392e-06
Epoch:84 , learn rate = 6.3113423300654325e-06
Epoch:85 , learn rate = 6.3113423300654325e-06
Epoch:86 , learn rate = 6.3113423300654325e-06
Epoch:87 , learn rate = 6.3113423300654325e-06
Epoch:88 , learn rate = 6.3113423300654325e-06
Epoch:89 , learn rate = 5.3646409805556175e-06
Epoch:90 , learn rate = 5.3646409805556175e-06
Epoch:91 , learn rate = 5.3646409805556175e-06
Epoch:92 , learn rate = 5.3646409805556175e-06
Epoch:93 , learn rate = 5.3646409805556175e-06
Epoch:94 , learn rate = 4.559944833472275e-06
Epoch:95 , learn rate = 4.559944833472275e-06
Epoch:96 , learn rate = 4.559944833472275e-06
Epoch:97 , learn rate = 4.559944833472275e-06
Epoch:98 , learn rate = 4.559944833472275e-06
Epoch:99 , learn rate = 3.875953108451433e-06
The learning-rate curve is shown below:
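The example above feeds the scheduler a simulated metric. In an actual training loop, the value passed to scheduler.step() is usually computed on held-out data every epoch. Below is a minimal sketch on made-up synthetic data. The data, the model size, factor=0.5 and patience=3 are all illustrative assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Synthetic regression data, made up for illustration: y = 2x + noise
x = torch.linspace(-1, 1, 200).unsqueeze(1)
y = 2 * x + 0.1 * torch.randn_like(x)
x_train, y_train = x[::2], y[::2]   # even indices for training
x_val, y_val = x[1::2], y[1::2]     # odd indices for validation

model = nn.Sequential(nn.Linear(1, 20), nn.ReLU(), nn.Linear(20, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)
loss_fn = nn.MSELoss()

lr_list = []
for epoch in range(60):
    model.train()
    optimizer.zero_grad()
    loss = loss_fn(model(x_train), y_train)
    loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(x_val), y_val).item()
    # Pass the validation loss; the scheduler decides whether to cut the lr
    scheduler.step(val_loss)
    lr_list.append(optimizer.param_groups[0]["lr"])

print(lr_list[-1])
```

Because mode="min", the scheduler treats a falling validation loss as improvement. The lr can only stay the same or be halved, never increase.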
2.8 torch.optim.lr_scheduler.CyclicLR
2.9 torch.optim.lr_scheduler.OneCycleLR
2.10 torch.optim.lr_scheduler.CosineAnnealingWarmRestarts
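Author: StubbornHuang
Copyright notice: This is an original article by the site owner. Please credit the original link when reposting!
Original title: Pytorch – Manually adjusting the learning rate and adjusting it with torch.optim.lr_scheduler
Original link: https://www.stubbornhuang.com/2229/
Published: 2022-08-04 8:52:57
Last modified: 2023-06-25 20:46:48
Notice: Unless otherwise stated or marked, all articles on this site are original. No individual or organization may copy, misappropriate, scrape, or republish this site's content on any website, book, or other media without the site's permission. If any content on this site infringes the legitimate rights of an original author, please contact us and we will handle it.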