Pytorch – 内置的CTC损失函数torch.nn.CTCLoss参数详解与使用示例
CTC(Connectionist Temporal Classification)主要是处理不定长序列对齐问题,而CTCLoss主要是计算连续未分段的时间序列与目标序列之间的损失。CTCLoss对输入与目标可能对齐的概率求和,产生一个相对于每个输入节点可微分的损失值。假设输入到目标的对应关系是“多对一”的,那么这限制了目标序列的长度,因此目标序列的长度必须是小于或者等于输入长度。
Pytorch也支持torch.nn.CTCLoss
函数,下文将对torch.nn.CTCLoss
的创建和使用做出详细的解释,以及以一个简单的例子说明torch.nn.CTCLoss
的使用,并给出使用时的一些需要注意的点。
1 torch.nn.CTCLoss
1.1 创建torch.nn.CTCLoss
函数形式
torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)
函数参数
- blank:int类型可选参数,空白标签所在的label值,默认为0
- reduction:str类型可选参数,处理输出loss的方式,可选择
none
、mean
、sum
,none
表示不对输出loss做任何处理,mean
表示输出loss将会除以目标长度,取批次的平均值,sum
则表示对输出loss做求和处理。默认为'mean'。 - zero_infinity:bool类型,表示是否将无限loss和相关的梯度归零。无限loss主要会出现在因输入太短而无法与目标对齐的情况。默认值为False。
1.2 使用torch.nn.CTCLoss计算损失值
函数形式
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
函数参数
- log_probs:形状为(T,N,C)的Tensor,其中T=input length,N=batch size,C=包含blank空白标签在内的所有标签的总数量。注意,log_probs一般需要经过
torch.nn.functional.log_softmax
处理后再送入到CTCLoss中。 - targets:形状为(N,S)或者sum(target_lengths)的Tensor。如果大小为(N,S),则其中N=batch size,S=max target length,即S为标签的最大长度;如果大小为sum(target_lengths),则表示为所有标签的长度之和。本参数target表示目标序列,目标序列中的每一个元素为分类索引,并且目标序列中的元素不能为blank标签索引(blank标签索引默认为0)。如果是(N,S)的形式,则target应扩展填充(padding操作)到最长标签的长度,然后堆叠起来(stack操作)。如果是sum(target_lengths)形式,则各个标签不需要填充,然后连接成一维的Tensor。
- input_lengths:形状为N的Tensor或者Tuple(元祖),其中N=batch size。该参数为每一个输入Tensor的长度,每一个标签的长度都必须小于或者等于T。一般来说,如果每一个输入都被填充到(padding操作)相等长度T的情况下,则每一个输入Tensor的长度都是固定的,即为T。
- target_lengths:形状为N的Tensor或者Tuple(元祖),其中N=batch size。该参数表示目标序列的长度。一般来说,如果每一个输出都被填充到(padding操作)相等长度的情况下,则每一个输出Tensor的长度都是固定的。如果targets的形状为(N,S),则target_lengths为每一个目标序列的终止索引S_{n},例如对于每一个batch,target_{n} = targets[n,0:S_{n}]。如果目标为串联所有目标序列的一维张量,则target_lengths必须小于或者等于S,然后所有target_lengths相加的总长度必须等于该一维张量的长度。
1.3 使用示例
1.3.1 当targets的形状为(N,S)时
当targets的形状为(N,S)时,batch中的所有目标序列被统一填充为当前batch最长的目标序列的长度,然后再计算CTCLoss。
比如当前Batch Size为16,即当前Batch中有16个目标序列,假设这16个目标序列中最长的序列长度为30,最小的目标序列为10,分类总数量(包括blank)为20,输入序列长度为50,则
# -*- coding: utf-8 -*-
import torch
if __name__ == '__main__':
T = 50 # 输入序列长度
C = 20 # 分类总数量,包括blank
N = 16 # Batch Size
S = 30 # 在当前Batch中最长的目标序列在padding之后的长度
S_min = 10 # 最小的目标序列长度
# 初始化一个随机输入序列,形状为(T,N,C)=>(50,16,20)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# 初始化一个随机目标序列,blank=0,1:C=classes,形状为(N,S)=>(16,30)
target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
# 初始化输入序列长度Tensor,形状为N,值为T
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# 初始化一个随机目标序列长度,形状为N,值最小为10,最大为30
target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
# 创建一个CTCLoss对象
ctc_loss = torch.nn.CTCLoss()
# 调用CTCLoss()对象计算损失值
loss = ctc_loss(input, target, input_lengths, target_lengths)
print(loss)
1.3.2 当targets的形状为sum(target_lengths)时
当targets的形状为sum(target_lengths)时,不需要将batch中的所有目标序列填充为最大目标序列长度,而是直接计算CTCLoss。
比如当前Batch Size为16,输入序列长度为50,分类总数量(包括blank)为20,则
# -*- coding: utf-8 -*-
import torch
if __name__ == '__main__':
T = 50 # 输入序列长度
C = 20 # 分类总数量,包括blank
N = 16 # Batch Size
# 初始化一个随机输入序列,形状为(T,N,C)=>(50,16,20)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# 初始化输入序列长度Tensor,形状为N,值为T
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# 初始化一个随机目标序列长度,形状为N,值最小为1,最大为T
target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
# 初始化一个随机目标序列,blank=0,1:C=classes,形状为(N,S)=>(16,30)
target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
# 创建一个CTCLoss对象
ctc_loss = torch.nn.CTCLoss()
# 调用CTCLoss()对象计算损失值
loss = ctc_loss(input, target, input_lengths, target_lengths)
print(loss)
1.4 以车牌识别为例简单说明如何使用CTCLoss
假设我们需要进行中文车牌识别,中文车牌识别中包含了下面所有的字符,包含blank在内共68个分类。
ALL_CLASSES = ['-','京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
'苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
'桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
'新',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
'W', 'X', 'Y', 'Z', 'I', 'O']
其中blank标签为"-",并处于索引0的位置。
假设当前的batch size为4,当前batch中包含以下四张车牌
湘E269JY
冀PL3N67
川R67283F
津AD68429
那么这四张车牌对应的标签依次为
[19, 46, 34, 38, 41, 50, 64]
[5, 55, 52, 35, 54, 38, 39]
[23, 57, 38, 39, 34, 40, 35, 47]
[3, 42, 45, 38, 40, 36, 34, 41]
长度依次为
7
7
8
8
那么我们可以通过以下代码计算CTCLoss,将sum(target_lengths)作为target的形状,
# -*- coding: utf-8 -*-
import torch
import itertools
if __name__ == '__main__':
# 车牌中可能出现的所有词,包括blank标签-,其中blank标签的索引为0
ALL_CLASSES = ['-','京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
'苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
'桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
'新',
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
'W', 'X', 'Y', 'Z', 'I', 'O']
# 假如batch size为4,即一个batch中有4个车牌
batch_target_groudtruth = [
['湘','E','2','6','9','J','Y'],
['冀','P','L','3','N','6','7'],
['川','R','6','7','2','8','3','F'],
['津','A','D','6','8','4','2','9']
]
# 根据车牌的文字信息获取当前batch中四个车牌的标签
target_label = []
for license_palte in batch_target_groudtruth:
temp_list = []
for x in license_palte:
for i,classes in enumerate(ALL_CLASSES):
if x == classes:
temp_list.append(i)
target_label.append(temp_list)
print(target_label)
# 获取四个车牌每个车牌的长度
target_lengths_list = [len(label) for label in target_label]
T = 50 # 输入序列长度
C = len(ALL_CLASSES) # 分类总数量,包括blank
N = 4 # Batch Size
# 初始化一个随机输入序列,形状为(T,N,C)=>(50,4,20)
input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
# 初始化输入序列长度Tensor,形状为N,值为T
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
# 以四个车牌的长度7,7,8,8作为目标序列长度
target_lengths = torch.tensor(target_lengths_list,dtype=torch.long)
# 以四个车牌的lable作为目标序列,采用sum(target_lengths)形式,形状为30的一维tensor
target = torch.tensor(list(itertools.chain.from_iterable(target_label)),dtype=torch.long)
# 创建一个CTCLoss对象
ctc_loss = torch.nn.CTCLoss(blank=0)
# 调用CTCLoss()对象计算损失值
loss = ctc_loss(input, target, input_lengths, target_lengths)
print(loss)
1.5 使用torch.nn.CTCLoss需要注意的点
- Pytorch torch.nn.CTCLoss官方文档这么说的:为了使用cuDNN,需要满足以下的条件。targets必须是concatenated format,所有的input_lengths的长度必须是T,blank必须等于0,target_lengths必须小于或者等于256,整形int参数必须明确类型(torch.int32)。
- blank空白标签一定要依据空白符在预测总字符集中的位置来设定,否则就会出错;
- 输出序列长度T尽量在模型设计时就要考虑到模型需要预测的最长序列,如需要预测的最长序列其长度为I,则理论上T应大于等于2I+1,这是因为CTCLoss假设在最坏情况下每个真实标签前后都至少有一个空白标签进行隔开以区分重复项;
- 输入的log_probs除了进行log_softmax()处理再送入CTCLoss外,还必须要调整其维度顺序,确保其shape为(T, N, C)!
参考链接
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 内置的CTC损失函数torch.nn.CTCLoss参数详解与使用示例
原文链接:https://www.stubbornhuang.com/2175/
发布于:2022年06月21日 9:14:34
修改于:2023年06月25日 21:05:45
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50