While reading the source code of wenet, an industrial-grade speech recognition framework, the mask functions provided in wenet/utils/mask.py turn out to be very important. They implement the chunk mask mechanism introduced in Section 3.2.2, Dynamic Chunk Training, of the wenet paper "Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition". Building on this mechanism, wenet unifies the training of streaming and non-streaming speech recognition models.

The utility script wenet/utils/mask.py currently provides the following functions:

  • subsequent_mask
  • subsequent_chunk_mask
  • add_optional_chunk_mask
  • make_pad_mask
  • make_non_pad_mask
  • mask_finished_scores
  • mask_finished_preds

This article describes each of these functions in detail.

1 mask.py

1.1 subsequent_mask

1. Function code

def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask

2. Function purpose

Given size, returns a mask Tensor of shape (size, size) with the upper triangle set to False (0).

3. Function parameters

  • size: int, the size of the mask tensor
  • device: torch.device, the device ("cpu" or "cuda") on which to create the returned Tensor. Defaults to torch.device("cpu").

4. Return value

Returns a Tensor of shape (size, size) with the upper triangle set to False (0).

5. Usage example

import torch

def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, full attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is needed, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    arange = torch.arange(size, device=device)
    mask = arange.expand(size, size)
    arange = arange.unsqueeze(-1)
    mask = mask <= arange
    return mask

if __name__ == '__main__':
    out = subsequent_mask(5)
    print(out)

Output

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [ True,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True]])
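Incidentally, the arange-and-compare construction above is equivalent to taking the lower triangle of an all-ones boolean matrix. A minimal sketch of that equivalence (the helper name subsequent_mask_tril is made up here, not part of wenet):

```python
import torch

def subsequent_mask_tril(size: int,
                         device: torch.device = torch.device("cpu")) -> torch.Tensor:
    # Lower-triangular boolean matrix: position i may attend to positions <= i.
    return torch.tril(torch.ones(size, size, device=device, dtype=torch.bool))

print(subsequent_mask_tril(3))
```

torch.tril allocates a full ones matrix first, so the arange-based version used by wenet avoids that intermediate; both produce the same mask.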

1.2 subsequent_chunk_mask

1. Function code

def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

2. Function purpose

Creates a chunk-based mask Tensor of shape (size, size) using the given chunk_size.

3. Function parameters

  • size: int, the size of the mask tensor
  • chunk_size: int, the size of each chunk
  • num_left_chunks: int, the number of left chunks visible to each position. If num_left_chunks is less than 0, all left chunks are visible; if it is greater than or equal to 0, exactly num_left_chunks left chunks are visible. Defaults to -1.
  • device: torch.device, the device ("cpu" or "cuda") on which to create the returned Tensor. Defaults to torch.device("cpu").

4. Return value

Returns a chunk-based mask Tensor of shape (size, size).

5. Usage example

import torch

def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

if __name__ == '__main__':
    print('use full chunk:')
    out_use_full_chunk = subsequent_chunk_mask(10, 2)
    print('{}\n'.format(out_use_full_chunk))


    print('use_1_left_chunk:')
    out_use_num_left_chunk = subsequent_chunk_mask(10, 2, 1)
    print('{}\n'.format(out_use_num_left_chunk))

    print('use_2_left_chunk:')
    out_use_num_left_chunk = subsequent_chunk_mask(10, 2, 2)
    print('{}\n'.format(out_use_num_left_chunk))

Output

use full chunk:
tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]])

use_1_left_chunk:
tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [False, False,  True,  True,  True,  True, False, False, False, False],
        [False, False,  True,  True,  True,  True, False, False, False, False],
        [False, False, False, False,  True,  True,  True,  True, False, False],
        [False, False, False, False,  True,  True,  True,  True, False, False],
        [False, False, False, False, False, False,  True,  True,  True,  True],
        [False, False, False, False, False, False,  True,  True,  True,  True]])

use_2_left_chunk:
tensor([[ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True, False, False, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True,  True,  True, False, False, False, False],
        [False, False,  True,  True,  True,  True,  True,  True, False, False],
        [False, False,  True,  True,  True,  True,  True,  True, False, False],
        [False, False, False, False,  True,  True,  True,  True,  True,  True],
        [False, False, False, False,  True,  True,  True,  True,  True,  True]])
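The row-by-row loop in subsequent_chunk_mask can also be expressed with broadcasting over chunk indices. The following is only a sketch for understanding (subsequent_chunk_mask_vec is a hypothetical name, not in wenet); it should produce the same masks as the loop version:

```python
import torch

def subsequent_chunk_mask_vec(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    # Chunk index of every position.
    chunk_idx = torch.arange(size, device=device) // chunk_size
    row = chunk_idx.unsqueeze(1)  # (size, 1): chunk of the query position
    col = chunk_idx.unsqueeze(0)  # (1, size): chunk of the key position
    # A position sees its own chunk and all earlier chunks ...
    mask = col <= row
    if num_left_chunks >= 0:
        # ... limited to at most num_left_chunks chunks on the left.
        mask = mask & (col >= row - num_left_chunks)
    return mask

print(subsequent_chunk_mask_vec(4, 2))
```

The comparison `col <= row` reproduces `ending = (i // chunk_size + 1) * chunk_size`, and the left bound reproduces `start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)`.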

1.3 make_pad_mask

1. Function code

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask

2. Function purpose

Given a Tensor of sequence lengths for a batch, returns a mask Tensor that is True over the padded part and False over the non-padded part.

3. Function parameters

  • lengths: tensor, the sequence lengths of a batch, of shape [B,], e.g. [10, 8, 5, 3, 1]
  • max_len: int, the maximum sequence length. If max_len equals 0, the longest sequence in the current batch is used as the maximum length; if max_len is greater than 0, max_len is used as the maximum length.

4. Return value

Returns a mask Tensor that is True (1) over the padded part and False (0) over the non-padded part.

5. Usage example

import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask

if __name__ == '__main__':
    print('not set max_len:')
    out = make_pad_mask(torch.tensor([10,5,3]))
    print('{},{}\n'.format(out,out.shape))

    print('set max_len = 8:')
    out = make_pad_mask(torch.tensor([10,5,3]),max_len=8)
    print('{},{}\n'.format(out,out.shape))

Output

not set max_len:
tensor([[False, False, False, False, False, False, False, False, False, False],
        [False, False, False, False, False,  True,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True,  True,  True]]),torch.Size([3, 10])

set max_len = 8:
tensor([[False, False, False, False, False, False, False, False],
        [False, False, False, False, False,  True,  True,  True],
        [False, False, False,  True,  True,  True,  True,  True]]),torch.Size([3, 8])
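A typical way to consume this mask is to zero out the padded frames of a (B, L, D) feature tensor with masked_fill. A minimal sketch (the toy tensors below are made up for illustration; make_pad_mask is condensed from the code above):

```python
import torch

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    # Same logic as above, condensed: position index >= length means padding.
    max_len = max_len if max_len > 0 else int(lengths.max().item())
    seq_range = torch.arange(max_len, device=lengths.device)
    return seq_range.unsqueeze(0) >= lengths.unsqueeze(1)

feats = torch.ones(2, 5, 3)              # (B, L, D) batch of features
lengths = torch.tensor([5, 3])
pad = make_pad_mask(lengths)             # (B, L), True on padding
feats = feats.masked_fill(pad.unsqueeze(-1), 0.0)
print(feats[1].sum())                    # only 3 of the 5 frames remain non-zero
```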

1.4 make_non_pad_mask

1. Function code

def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    length. To avoid the padding part passing values to context-dependent
    blocks such as attention or convolution, the padding part is
    masked.

    This pad_mask is used in both encoder and decoder.

    1 for non-padded part and 0 for padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of non-padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)

2. Function purpose

Given a Tensor of sequence lengths for a batch, returns a mask Tensor that is False (0) over the padded part and True (1) over the non-padded part.

3. Function parameters

  • lengths: tensor, the sequence lengths of a batch, of shape [B,], e.g. [10, 8, 5, 3, 1]

4. Return value

Returns a mask Tensor, sized by the longest sequence in the batch, that is False (0) over the padded part and True (1) over the non-padded part.

5. Usage example

import torch

def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Make mask tensor containing indices of non-padded part.

    The sequences in a batch may have different lengths. To enable
    batch computing, padding is needed to make all sequences the same
    length. To avoid the padding part passing values to context-dependent
    blocks such as attention or convolution, the padding part is
    masked.

    This pad_mask is used in both encoder and decoder.

    1 for non-padded part and 0 for padded part.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of non-padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_non_pad_mask(lengths)
        masks = [[1, 1, 1, 1 ,1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    return ~make_pad_mask(lengths)

if __name__ == '__main__':
    out = make_non_pad_mask(torch.tensor([10,5,3]))
    print('{},{}\n'.format(out,out.shape))

Output

tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False, False, False, False]]),torch.Size([3, 10])
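One common use of the non-pad mask is length-aware mean pooling, where padded frames must not contribute to the average. A sketch on made-up toy data (make_non_pad_mask is condensed from the code above):

```python
import torch

def make_non_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Position index < length means a valid (non-padded) frame.
    max_len = int(lengths.max().item())
    return torch.arange(max_len, device=lengths.device).unsqueeze(0) < lengths.unsqueeze(1)

x = torch.arange(12.0).reshape(2, 3, 2)           # (B=2, L=3, D=2)
lengths = torch.tensor([3, 2])
mask = make_non_pad_mask(lengths).unsqueeze(-1)   # (B, L, 1)
# Zero the padded frames, then divide by the true lengths, not by L.
mean = (x * mask).sum(dim=1) / lengths.unsqueeze(-1)
print(mean)
```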

1.5 add_optional_chunk_mask

1. Function code

def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        mask (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks

2. Function purpose

According to the configured parameters, returns the chunk mask Tensor of shape (B, L, L) corresponding to the input xs.

3. Function parameters

  • xs: tensor, the padded input Tensor of shape (B, L, D), where L is the maximum length
  • masks: tensor, the padding mask for xs, of shape (B, 1, L)
  • use_dynamic_chunk: whether to use dynamic chunk
  • use_dynamic_left_chunk: whether to use dynamic left chunk
  • decoding_chunk_size: the decoding chunk size for dynamic chunk. It defaults to 0 during training, which uses a random dynamic chunk size; if decoding_chunk_size is less than 0, the full chunk is used; if decoding_chunk_size is greater than 0, the fixed decoding_chunk_size is used
  • static_chunk_size: the chunk size for static chunk training/decoding; whatever value static_chunk_size is set to, it is ignored when use_dynamic_chunk is true
  • num_decoding_left_chunks: how many left chunks to use at decoding time, where the chunk size is decoding_chunk_size. If num_decoding_left_chunks is less than 0, all left chunks are used; if num_decoding_left_chunks is greater than or equal to 0, exactly num_decoding_left_chunks left chunks are used

4. Return value

Returns the chunk mask Tensor of shape (B, L, L) corresponding to the input xs.

5. Usage example

import torch

def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret

def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
    Returns:
        torch.Tensor: Mask tensor containing indices of padded part.

    Examples:
        >>> lengths = [5, 3, 2]
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    batch_size = lengths.size(0)
    max_len = max_len if max_len > 0 else lengths.max().item()
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask

def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        mask (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks

if __name__ == '__main__':
    input = torch.rand(size=(1,10,10))
    input_length = torch.tensor([10])
    padding_mask = ~make_pad_mask(input_length,max_len=10).unsqueeze(1)

    print('padding_mask:')
    print('{}\n'.format(padding_mask))

    chunk_mask = add_optional_chunk_mask(xs=input,
                                  masks=padding_mask,
                                  use_dynamic_chunk=False,
                                  use_dynamic_left_chunk=False,
                                  decoding_chunk_size=-1,
                                  static_chunk_size=3,
                                  num_decoding_left_chunks=-1)

    print('static_chunk_size = 3, full left chunk')
    print('{}\n'.format(chunk_mask))

    chunk_mask = add_optional_chunk_mask(xs=input,
                                  masks=padding_mask,
                                  use_dynamic_chunk=False,
                                  use_dynamic_left_chunk=False,
                                  decoding_chunk_size=-1,
                                  static_chunk_size=3,
                                  num_decoding_left_chunks=1)
    print('static_chunk_size = 3, 1 left chunk')
    print('{}\n'.format(chunk_mask))

Output

padding_mask:
tensor([[[True, True, True, True, True, True, True, True, True, True]]])

static_chunk_size = 3, full left chunk
tensor([[[ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]]])

static_chunk_size = 3, 1 left chunk
tensor([[[ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [False, False, False,  True,  True,  True,  True,  True,  True, False],
         [False, False, False,  True,  True,  True,  True,  True,  True, False],
         [False, False, False,  True,  True,  True,  True,  True,  True, False],
         [False, False, False, False, False, False,  True,  True,  True,  True]]])
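The heart of add_optional_chunk_mask is a single broadcasting step: the (B, 1, L) padding mask is AND-ed with a (1, L, L) chunk mask to give a per-utterance (B, L, L) attention mask. A condensed sketch of just that step (toy sizes are made up; the chunk mask is built the same way as subsequent_chunk_mask with full left context):

```python
import torch

B, L, chunk = 2, 4, 2
lengths = torch.tensor([4, 2])

# (B, 1, L) padding mask: True on valid frames.
pad_mask = (torch.arange(L).unsqueeze(0) < lengths.unsqueeze(1)).unsqueeze(1)

# (1, L, L) chunk mask with full left context.
chunk_idx = torch.arange(L) // chunk
chunk_mask = (chunk_idx.unsqueeze(0) <= chunk_idx.unsqueeze(1)).unsqueeze(0)

attn_mask = pad_mask & chunk_mask  # broadcasts to (B, L, L)
print(attn_mask.shape)
```

For the second utterance (length 2), every row of its (L, L) mask is additionally cut down to the first two columns by the padding mask.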

1.6 mask_finished_scores

1. Function code

def mask_finished_scores(score: torch.Tensor,
                         flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    aims to give one branch a zero score and the rest -inf score.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
                               dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
                             dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    score.masked_fill_(unfinished, -float('inf'))
    score.masked_fill_(finished, 0)
    return score

2. Function purpose

Once a sequence is finished, only one alive branch is kept. This function sets the score of one branch to 0 and the scores of the remaining branches to -inf.

3. Function parameters

  • score: tensor, a real-valued array of shape (batch_size * beam_size, beam_size)
  • flag: tensor, a bool array of shape (batch_size * beam_size, 1)

4. Return value

Returns a Tensor of shape (batch_size * beam_size, beam_size).

5. Usage example

import torch

def mask_finished_scores(score: torch.Tensor,
                         flag: torch.Tensor) -> torch.Tensor:
    """
    If a sequence is finished, we only allow one alive branch. This function
    aims to give one branch a zero score and the rest -inf score.

    Args:
        score (torch.Tensor): A real value array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = score.size(-1)
    zero_mask = torch.zeros_like(flag, dtype=torch.bool)
    if beam_size > 1:
        unfinished = torch.cat((zero_mask, flag.repeat([1, beam_size - 1])),
                               dim=1)
        finished = torch.cat((flag, zero_mask.repeat([1, beam_size - 1])),
                             dim=1)
    else:
        unfinished = zero_mask
        finished = flag
    score.masked_fill_(unfinished, -float('inf'))
    score.masked_fill_(finished, 0)
    return score



if __name__ == '__main__':
    batch_size = 1
    beam_size = 5
    score = torch.rand(size=(5,5))
    flag = torch.tensor([True,False,True,True,False]).unsqueeze(1)

    out = mask_finished_scores(score=score,flag=flag)
    print(out.shape)
    print(out)

Output

torch.Size([5, 5])
tensor([[0.0000,   -inf,   -inf,   -inf,   -inf],
        [0.6803, 0.0728, 0.4417, 0.4770, 0.1615],
        [0.0000,   -inf,   -inf,   -inf,   -inf],
        [0.0000,   -inf,   -inf,   -inf,   -inf],
        [0.5874, 0.7472, 0.1647, 0.8151, 0.6123]])
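To see why the 0 / -inf pattern works: a subsequent topk over each row picks the single zero-score branch of a finished hypothesis, while live rows keep their real best scores. A small sketch on hand-written scores (the values are made up):

```python
import torch

# Row 0 mimics a finished hypothesis after mask_finished_scores;
# row 1 is still alive and keeps its real scores.
score = torch.tensor([[0.0, -float('inf'), -float('inf')],
                      [0.2, 0.5, 0.1]])
top_scores, top_idx = score.topk(k=1, dim=-1)
print(top_scores)  # the finished row contributes exactly its one zero-score branch
```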

1.7 mask_finished_preds

1. Function code

def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
                        eos: int) -> torch.Tensor:
    """
    If a sequence is finished, all of its branch should be <eos>

    Args:
        pred (torch.Tensor): A int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)

2. Function purpose

Once a sequence is finished, all of its branches should be <eos>.

3. Function parameters

  • pred: tensor, an int array of shape (batch_size * beam_size, beam_size)
  • flag: tensor, a bool array of shape (batch_size * beam_size, 1)
  • eos: int, the id of the <eos> token

4. Return value

Returns a Tensor of shape (batch_size * beam_size, beam_size).

5. Usage example

import torch

def mask_finished_preds(pred: torch.Tensor, flag: torch.Tensor,
                        eos: int) -> torch.Tensor:
    """
    If a sequence is finished, all of its branch should be <eos>

    Args:
        pred (torch.Tensor): A int array with shape
            (batch_size * beam_size, beam_size).
        flag (torch.Tensor): A bool array with shape
            (batch_size * beam_size, 1).

    Returns:
        torch.Tensor: (batch_size * beam_size, beam_size).
    """
    beam_size = pred.size(-1)
    finished = flag.repeat([1, beam_size])
    return pred.masked_fill_(finished, eos)


if __name__ == '__main__':
    batch_size = 1
    beam_size = 5
    score = torch.randint(low=1,high=5,size=(5,5))
    flag = torch.tensor([True,False,True,True,False]).unsqueeze(1)

    out = mask_finished_preds(pred=score,flag=flag,eos=6)
    print(out.shape)
    print(out)

Output

torch.Size([5, 5])
tensor([[6, 6, 6, 6, 6],
        [2, 3, 2, 1, 2],
        [6, 6, 6, 6, 6],
        [6, 6, 6, 6, 6],
        [1, 3, 2, 2, 4]])
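In a beam-search loop, the flag fed to mask_finished_preds (and mask_finished_scores) is typically derived by checking whether the most recently emitted token equals <eos>. A hypothetical sketch of that step (the variable names and token values are made up):

```python
import torch

eos = 6
# Last emitted token of each hypothesis, shape (batch_size * beam_size, 1).
last_tokens = torch.tensor([[6], [2], [6], [1], [6]])
flag = last_tokens.eq(eos)  # bool flag marking finished hypotheses
print(flag.squeeze(1))
```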