1 wenet非流式流式混合训练机制

wenet实现了语音识别非流式与流式混合训练的机制。通过细读源码,其主要是通过动态修改网络的Encoder层(在wenet中主要使用了TransformerEncoder和Conformer)的attention mask来影响Encoder层中Self-Attention的计算结果,从而实现一次训练既可以实现非流式语音识别也可以实现流式语音识别。

在wenet中提出了dynamic chunk的训练机制,在训练过程中,每一个batch都使用dynamic chunk size进行训练,换句话说,每一个batch所使用的chunk size是不同的,是动态变化的。

记当前batch的最大的序列长度为L,而dynamic chunk size在1到该batch的最大序列长度L的范围内随机取值,如果随机值大于\frac{L}{2},则将chunk size的值取为L,如果随机值小于或者等于\frac{L}{2},则chunk size从[1,25]的范围内随机取值。如chunk size取值为batch中最大序列长度L,则表明使用全部上下文注意,这种则作为非流式训练的组成部分;而另一种将chunk size取值为[1,25],则表示使用部分上下文注意,这种取值方式则作为流式训练的组成部分。

而其核心代码主要是wenet/utils/mask.py中的add_optional_chunk_mask函数,下面将通过这个函数的代码梳理下wenet如何通过该函数构造出不同的attention mask。

1.1 代码梳理

add_optional_chunk_mask源代码

def add_optional_chunk_mask(xs: torch.Tensor, masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int, static_chunk_size: int,
                            num_decoding_left_chunks: int):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        mask (torch.Tensor): mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks

    Returns:
        torch.Tensor: chunk mask of the input xs.
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks

add_optional_chunk_mask函数的函数参数详情如下:

  • xs:tensor,已padding后的Tensor,形状为(B,L,D),其中L为最大长度
  • masks:tensor,xs对应的padding mask,形状为(B,1,L)
  • use_dynamic_chunk:是否使用dynamic chunk
  • use_dynamic_left_chunk:是否使用dynamic left chunk
  • decoding_chunk_size:dynamic chunk的decode chunk大小,在模型训练时默认为0,使用随机的dynamic chunk;如果decode_chunk_size小于0,使用full chunk;如果decode_chunk_size大于0,使用固定的decode_chunk_size
  • static_chunk_size:在static chunk训练和解码的chunk size,当use_dynamic_chunk为true时,不管static_chunk_size设置为什么值都会被忽略
  • num_decoding_left_chunks:在解码时需用到多少个left chunk,解码的chunk_size为decode_chunk_size,如果num_decoding_left_chunks小于0,则使用全部的left chunk;如果num_deocding_left_chunks大于等于0,则使用num_deocding_left_chunks个left chunk

提供了3种Encoder层attention mask的生成方式,

  • 第一种是不计算attention mask,直接返回原有的padding mask
  • 第二种是生成指定的static_chunk_size的attention mask
  • 第三种则是dynamic chunk的attention mask

1.1.1 不计算attention mask

相关代码

    else:
        chunk_masks = masks

代码很简洁,就是直接将输入的padding mask直接赋值给chunk mask进行返回,不进行任何操作。

在此种方式下就是原生的TransformerEncoder或者ConformerEncoder的编码方式,所有的数据都参与编码过程。

1.1.2 static_chunk_size的attention mask

相关代码

    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)

在static_chunk_size的情况下,不需要向dynamic模式那样动态改变chunk size,所以相对来说比较简单,在参数num_left_chunksstatic_chunk_size确定的情况下,直接使用subsequent_chunk_mask函数求解出chunk mask,然后与输入的padding mask做&操作得到最后的mask。subsequent_chunk_mask函数的用法可以参考我另一篇博文深度学习 – 语音识别框架wenet源码wenet/utils/mask.py中的mask机制,该函数的主要作用是使用设定的chunk_size创建形状为(size,size)的mask Tensor。

1.1.3 dynamic_chunk_size的attention mask

相关代码

    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)

终于来到正题dynamic chunk的代码了。

上述代码中decoding_chunk_size不等于0主要是用于解码;而训练时,decoding_chunk_size等于0,训练的关键代码为

        else:
            # chunk size is either [1, 25] or full context(max_len).
            # Since we use 4 times subsampling and allow up to 1s(100 frames)
            # delay, the maximum frame is 100 / 4 = 25.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)

上述代码中,max_len为当前batch的序列最大长度,首先通过

chunk_size = torch.randint(1, max_len, (1, )).item()

在[1,max_len]的范围内随机取值,然后将随机取值进行判断

            if chunk_size > max_len // 2:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % 25 + 1

如果随机取值大于序列长度的一半,则将chunk_size设置为序列的最大长度值;如果如果随机取值小于或者等于序列长度的一半,则取chunk_size = chunk_size % 25 + 1,即在[1,25]的范围内随机取值。如果chunk_size设置为序列的最大长度值则为非流式训练,如果chunk_size在[1,25]的范围内随机取值则为流式训练。

同时,为了减少模型在解码时不会因为chunk size设置的值所带来识别效果大幅变差的情况,使用了dynamic left chunk,相关的代码如下,

                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()

上述代码的核心思想是,如果我们使用了dynamic left chunk,则先根据chunk_sizemax_len算出最大的左侧chunk数量max_left_chunk,然后从[0, max_left_chunk]的范围内随机取值作为num_left_chunks

等上述所有过程处理完毕,则使用subsequent_chunk_mask函数求解出chunk mask,然后与输入的padding mask做&操作得到最后的mask。

2 总结

静下心来细读wenet的源码,还是收获颇多,终于解答了自己“为什么wenet既可以实现流式语音识别又可以实现非流式语音识别?”的疑问,它的这种设计对需要进行流式识别的领域有诸多启发,还是感谢wenet团队的无私奉献。