Deep Learning – The mask mechanism in wenet/utils/mask.py of the speech recognition framework wenet
While reading the source code of the industrial-grade speech recognition framework wenet, the mask functions provided in wenet/utils/mask.py prove to be essential. They implement the chunk mask mechanism introduced in Section 3.2.2, Dynamic Chunk Training, of the wenet paper "Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition"; building on this mechanism, wenet unifies the training of streaming and non-streaming speech recognition models.
The utility script wenet/utils/mask.py currently provides the following functions:
- subsequent_mask
- subsequent_chunk_mask
- add_optional_chunk_mask
- make_pad_mask
- make_non_pad_mask
- mask_finished_scores
- mask_finished_preds
This article describes each of these functions in detail.
1 mask.py
1.1 subsequent_mask
1. Function code
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask
2. Function purpose
Given size, returns a mask tensor of shape (size, size) whose upper triangle is False (0), i.e. a lower-triangular causal mask.
3. Function parameters
- size: int, size of the mask tensor
- device: torch.device, "cpu" or "cuda"; the device on which the returned tensor is created. Defaults to "cpu".
4. Function return value
A tensor of shape (size, size) whose upper triangle is False (0).
5. Usage example
import torch
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask

if __name__ == '__main__':
    out = subsequent_mask(5)
    print(out)
Output:
tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])
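As a sanity check, the same lower-triangular pattern can also be produced with torch.tril; the snippet below is my own (not part of mask.py) and assumes subsequent_mask from above is in scope. One plausible reason for wenet's arange-based formulation is friendliness to export/tracing paths, but take that rationale as my assumption rather than a documented fact.
import torch

# equivalent lower-triangular mask built with torch.tril (sanity check, not wenet code)
alt = torch.tril(torch.ones(5, 5, dtype=torch.bool))
print(torch.equal(alt, subsequent_mask(5)))  # True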
1.2 subsequent_chunk_mask
1. Function code
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret
2. Function purpose
Creates a chunk-based attention mask tensor of shape (size, size) from the given chunk_size: each position can attend to all positions in its own chunk and, depending on num_left_chunks, to some or all chunks to its left.
3. Function parameters
- size: int, size of the mask tensor
- chunk_size: int, size of each chunk
- num_left_chunks: int, number of left chunks each position may attend to. If num_left_chunks is less than 0, all left chunks are used; if num_left_chunks is greater than or equal to 0, only num_left_chunks left chunks are used. Defaults to -1.
- device: torch.device, "cpu" or "cuda"; the device on which the returned tensor is created. Defaults to "cpu".
4. Function return value
A chunk-based mask tensor of shape (size, size).
5. Usage example
import torch
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

if __name__ == '__main__':
    print('use full chunk:')
    out_use_full_chunk = subsequent_chunk_mask(10, 2)
    print('{}\n'.format(out_use_full_chunk))
    print('use_1_left_chunk:')
    out_use_num_left_chunk = subsequent_chunk_mask(10, 2, 1)
    print('{}\n'.format(out_use_num_left_chunk))
    print('use_2_left_chunk:')
    out_use_num_left_chunk = subsequent_chunk_mask(10, 2, 2)
    print('{}\n'.format(out_use_num_left_chunk))
Output:
use full chunk:
tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])

use_1_left_chunk:
tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [False, False,  True,  True,  True,  True, False, False, False, False],
        [False, False,  True,  True,  True,  True, False, False, False, False],
        [False, False, False, False,  True,  True,  True,  True, False, False],
        [False, False, False, False,  True,  True,  True,  True, False, False],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True]])

use_2_left_chunk:
tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [False, False,  True,  True,  True,  True,  True,  True, False, False],
        [False, False,  True,  True,  True,  True,  True,  True, False, False],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True]])
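The double loop above is clear but runs in Python per row. As a side note, the same mask can be built with broadcasting alone; the function below is my own loop-free sketch (the name subsequent_chunk_mask_vec is made up here, and the final assert assumes the loop version from above is in scope):
import torch

def subsequent_chunk_mask_vec(size: int,
                              chunk_size: int,
                              num_left_chunks: int = -1,
                              device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # every position p belongs to chunk p // chunk_size
    chunk_idx = torch.arange(size, device=device) // chunk_size
    ci = chunk_idx.unsqueeze(1)  # (size, 1): chunk of the query position i
    cj = chunk_idx.unsqueeze(0)  # (1, size): chunk of the key position j
    # j is visible from i iff j's chunk is not to the right of i's chunk ...
    mask = cj <= ci
    if num_left_chunks >= 0:
        # ... and, with limited left context, at most num_left_chunks to the left
        mask &= cj >= ci - num_left_chunks
    return mask

# sanity check against the loop-based version above
assert torch.equal(subsequent_chunk_mask_vec(10, 2, 1), subsequent_chunk_mask(10, 2, 1))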
1.3 make_pad_mask
1. Function code
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): fixed mask width; if 0, the max value in lengths is used.

    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
2. Function purpose
Given a tensor of sequence lengths for one batch, returns a mask tensor that is True on the padded part and False on the non-padded part.
3. Function parameters
- lengths: tensor, the sequence lengths of one batch, of shape (B,), e.g. [10, 8, 5, 3, 1]
- max_len: int, the maximum sequence length. If max_len equals 0, the longest sequence in the current batch determines the mask width; if max_len is greater than 0, max_len is used as the mask width.
4. Function return value
A mask tensor that is True (1) on the padded part and False (0) on the non-padded part.
5. Usage example
import torch
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): fixed mask width; if 0, the max value in lengths is used.

    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask

if __name__ == '__main__':
    print('not set max_len:')
    out = make_pad_mask(torch.tensor([10, 5, 3]))
    print('{},{}\n'.format(out, out.shape))
    print('set max_len = 8:')
    out = make_pad_mask(torch.tensor([10, 5, 3]), max_len=8)
    print('{},{}\n'.format(out, out.shape))
Output:
not set max_len:
tensor([[False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True]]),torch.Size([3, 10])

set max_len = 8:
tensor([[False, False, False, False, False, False, False, False],
        [False, False, False, False, False,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True]]),torch.Size([3, 8])
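A typical downstream use (my own illustrative snippet, not from mask.py, assuming make_pad_mask from above is in scope) is to zero out the padded frames of a feature tensor before pooling or loss computation:
import torch

lengths = torch.tensor([10, 5, 3])
xs = torch.rand(3, 10, 4)                         # (B, L, D) padded features
pad_mask = make_pad_mask(lengths)                 # (B, L), True on padded positions
xs = xs.masked_fill(pad_mask.unsqueeze(-1), 0.0)  # broadcast over D, zero the padding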
1.4 make_non_pad_mask
1. Function code
def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    size. To prevent the padded part from passing values into context
    dependent blocks such as attention or convolution, the padded part
    is masked.

    This pad_mask is used in both encoder and decoder.

    1 for non-padded part and 0 for padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).

    Returns:
        torch.Tensor: Mask tensor containing indices of non-padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)
2. Function purpose
Given a tensor of sequence lengths for one batch, returns a mask tensor that is False (0) on the padded part and True (1) on the non-padded part.
3. Function parameters
- lengths: tensor, the sequence lengths of one batch, of shape (B,), e.g. [10, 8, 5, 3, 1]
4. Function return value
A mask tensor, sized to the longest sequence in the batch, that is False (0) on the padded part and True (1) on the non-padded part.
5. Usage example
import torch
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # as defined in section 1.3; included so this example runs standalone
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    return seq_range_expand >= seq_length_expand

def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    size. To prevent the padded part from passing values into context
    dependent blocks such as attention or convolution, the padded part
    is masked.

    This pad_mask is used in both encoder and decoder.

    1 for non-padded part and 0 for padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).

    Returns:
        torch.Tensor: Mask tensor containing indices of non-padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)

if __name__ == '__main__':
    out = make_non_pad_mask(torch.tensor([10, 5, 3]))
    print('{},{}\n'.format(out, out.shape))
Output:
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False]]),torch.Size([3, 10])
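In the encoder this non-pad mask is typically unsqueezed to shape (B, 1, L) so it can broadcast against (L, L) attention masks; that is exactly the masks shape that add_optional_chunk_mask below expects. A one-line illustration (my own, assuming make_non_pad_mask from above is in scope):
masks = make_non_pad_mask(lengths).unsqueeze(1)  # (B, L) -> (B, 1, L)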
1.5 add_optional_chunk_mask
1. Function code
def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks
2. Function purpose
Depending on the configured parameters, returns the chunk mask tensor of shape (B, L, L) for the input xs.
3. Function parameters
- xs: tensor, the padded input tensor of shape (B, L, D), where L is the max length
- masks: tensor, the padding mask for xs, of shape (B, 1, L)
- use_dynamic_chunk: bool, whether to use dynamic chunk
- use_dynamic_left_chunk: bool, whether to use dynamic left chunk
- decoding_chunk_size: int, the decoding chunk size for dynamic chunk. During training it defaults to 0, which picks a random dynamic chunk size; if decoding_chunk_size is less than 0, the full chunk is used; if decoding_chunk_size is greater than 0, the fixed decoding_chunk_size is used
- static_chunk_size: int, the chunk size for static chunk training/decoding; whatever value static_chunk_size is set to, it is ignored when use_dynamic_chunk is true
- num_decoding_left_chunks: int, how many left chunks to use at decoding time, where the decoding chunk size is decoding_chunk_size. If num_decoding_left_chunks is less than 0, all left chunks are used; if num_decoding_left_chunks is greater than or equal to 0, num_decoding_left_chunks left chunks are used
4. Function return value
The chunk mask tensor of shape (B, L, L) for the input xs.
5. Usage example
import torch
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): fixed mask width; if 0, the max value in lengths is used.

    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask
def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0; if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks
if __name__ == '__main__':
    xs = torch.rand(size=(1, 10, 10))
    xs_lengths = torch.tensor([10])
    padding_mask = ~make_pad_mask(xs_lengths, max_len=10).unsqueeze(1)
    print('padding_mask:')
    print('{}\n'.format(padding_mask))
    chunk_mask = add_optional_chunk_mask(xs=xs,
                                         masks=padding_mask,
                                         use_dynamic_chunk=False,
                                         use_dynamic_left_chunk=False,
                                         decoding_chunk_size=-1,
                                         static_chunk_size=3,
                                         num_decoding_left_chunks=-1)
    print('static_chunk_size = 3, full left chunk')
    print('{}\n'.format(chunk_mask))
    chunk_mask = add_optional_chunk_mask(xs=xs,
                                         masks=padding_mask,
                                         use_dynamic_chunk=False,
                                         use_dynamic_left_chunk=False,
                                         decoding_chunk_size=-1,
                                         static_chunk_size=3,
                                         num_decoding_left_chunks=1)
    print('static_chunk_size = 3, 1 left chunk')
    print('{}\n'.format(chunk_mask))
Output:
padding_mask:
tensor([[[True, True, True, True, True, True, True, True, True, True]]])

static_chunk_size = 3, full left chunk
tensor([[[ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])

static_chunk_size = 3, 1 left chunk
tensor([[[ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [False, False, False,  True,  True,  True,  True,  True,  True, False],
         [False, False, False,  True,  True,  True,  True,  True,  True, False],
         [False, False, False,  True,  True,  True,  True,  True,  True, False],
         [False, False, False, False, False, False,  True,  True,  True,  True]]])
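The step worth pausing on is chunk_masks = masks & chunk_masks: a (B, 1, L) padding mask AND-ed with a (1, L, L) chunk mask broadcasts to (B, L, L), so a key position is attendable only if it is both inside the allowed chunks and not padding. A minimal demonstration with toy values of my own:
import torch

pad = torch.tensor([[[True, True, True, False]]])                    # (1, 1, 4): last frame is padding
chunk = torch.tril(torch.ones(4, 4, dtype=torch.bool)).unsqueeze(0)  # (1, 4, 4) causal stand-in for a chunk mask
combined = pad & chunk                                               # broadcasts to (1, 4, 4)
print(combined.shape)  # torch.Size([1, 4, 4])
print(combined[0, 3])  # tensor([ True,  True,  True, False]): the padded column is masked in every row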
1.6 mask_finished_scores
1. Function code
def mask_finished_scores(score: torch.Tensor,
                         flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    aims to give one branch a zero score and the rest -inf score.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
                               dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
                             dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    score.masked_fill_(unfinished, -float('inf'))
    score.masked_fill_(finished, 0)
    return score
2. Function purpose
Once a sequence has finished, only one alive branch is allowed. This function sets the score of the first branch of each finished sequence to 0 and the scores of its remaining branches to -inf, so that beam search keeps exactly one hypothesis for that sequence.
3. Function parameters
- score: tensor, a real-valued array of shape (batch_size * beam_size, beam_size)
- flag: tensor, a bool array of shape (batch_size * beam_size, 1)
4. Function return value
A tensor of shape (batch_size * beam_size, beam_size).
5. Usage example
import torch
def mask_finished_scores(score: torch.Tensor,
                         flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    aims to give one branch a zero score and the rest -inf score.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
                               dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
                             dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    score.masked_fill_(unfinished, -float('inf'))
    score.masked_fill_(finished, 0)
    return score

if __name__ == '__main__':
    batch_size = 1
    beam_size = 5
    score = torch.rand(size=(5, 5))
    flag = torch.tensor([True, False, True, True, False]).unsqueeze(1)
    out = mask_finished_scores(score=score, flag=flag)
    print(out.shape)
    print(out)
Output:
torch.Size([5, 5])
tensor([[0.0000,   -inf,   -inf,   -inf,   -inf],
        [0.6803, 0.0728, 0.4417, 0.4770, 0.1615],
        [0.0000,   -inf,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf],
        [0.5874, 0.7472, 0.1647, 0.8151, 0.6123]])
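Why one zero and -inf for the rest? On the next topk over these scores, a finished row can contribute at most one candidate with a finite score (the zero-scored first branch), so a finished hypothesis is carried along once instead of spawning beam_size duplicates. A quick illustration with toy values of my own:
import torch

row = torch.tensor([0.0, float('-inf'), float('-inf')])  # a finished row after masking
values, indices = row.topk(2)
print(values)   # tensor([0., -inf]): only branch 0 keeps a finite score
print(indices)  # tensor([0, 1])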
1.7 mask_finished_preds
1. Function code
def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
                        eos: int) -> torch.Tensor:
    """
    If a sequence is finished, all of its branches should be <eos>.

    Args:
        pred (torch.Tensor): An int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).
        eos (int): the id of <eos>.

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)
2. Function purpose
If a sequence has finished, all of its branches are set to <eos>.
3. Function parameters
- pred: tensor, an int array of shape (batch_size * beam_size, beam_size)
- flag: tensor, a bool array of shape (batch_size * beam_size, 1)
- eos: int, the id of <eos>
4. Function return value
A tensor of shape (batch_size * beam_size, beam_size).
5. Usage example
import torch
def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
                        eos: int) -> torch.Tensor:
    """
    If a sequence is finished, all of its branches should be <eos>.

    Args:
        pred (torch.Tensor): An int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).
        eos (int): the id of <eos>.

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)

if __name__ == '__main__':
    batch_size = 1
    beam_size = 5
    pred = torch.randint(low=1, high=5, size=(5, 5))
    flag = torch.tensor([True, False, True, True, False]).unsqueeze(1)
    out = mask_finished_preds(pred=pred, flag=flag, eos=6)
    print(out.shape)
    print(out)
Output:
torch.Size([5, 5])
tensor([[6, 6, 6, 6, 6],
        [2, 3, 2, 1, 2],
        [6, 6, 6, 6, 6],
        [6, 6, 6, 6, 6],
        [1, 3, 2, 2, 4]])
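To see the two helpers working together, here is a hypothetical single step of batched beam search (a sketch of mine, not wenet's decoder; it assumes mask_finished_scores and mask_finished_preds from sections 1.6 and 1.7 are in scope, and end_flag marks hypotheses that have already emitted <eos>):
import torch

beam_size, eos = 3, 6
logp = torch.log_softmax(torch.rand(beam_size, beam_size), dim=-1)  # (B*beam, beam) next-token scores, B = 1
end_flag = torch.tensor([[False], [True], [False]])                 # hypothesis 1 already finished

logp = mask_finished_scores(logp, end_flag)            # finished row -> [0, -inf, -inf]
top_scores, top_ids = logp.topk(beam_size, dim=-1)
top_ids = mask_finished_preds(top_ids, end_flag, eos)  # finished row predicts only <eos>
print(top_ids[1])  # tensor([6, 6, 6])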