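In speech recognition and OCR, the last step of inference is to apply a CTC decoding algorithm to the predicted probability matrix to find the most likely output sequence. Common CTC decoding algorithms include Greedy Search Decode, Beam Search Decode and Prefix Beam Search Decode. Of these, Greedy Search Decode and Prefix Beam Search Decode are the most widely used. This article implements all three in Python.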

1 Greedy Search Decode

1.1 Principle

Greedy Search Decode takes the most probable character at each time step, merges adjacent repeated characters, and finally removes the blank characters.
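The order of the last two steps matters. Repeats are merged before blanks are removed, so a blank between two identical characters keeps them as two separate characters, while identical characters in consecutive frames collapse into one. The small sketch below applies this collapse rule to a hand-written path. The helper name collapse_ctc_path is hypothetical, and index 0 is assumed to be the blank, as in the code that follows.

```python
from itertools import groupby


def collapse_ctc_path(path, blank=0):
    """Hypothetical helper: merge adjacent repeats, then drop blanks."""
    merged = [key for key, _ in groupby(path)]
    return [key for key in merged if key != blank]


# Index 0 is the blank. The adjacent 3s in frames 1-2 merge into one 3,
# the blank in frame 3 separates it from the 3 in frame 4,
# and the two adjacent 5s merge into one 5.
path = [0, 3, 3, 0, 3, 5, 5, 0]
print(collapse_ctc_path(path))  # [3, 3, 5]

# Removing blanks first and merging afterwards gives a different (wrong) result.
wrong = [key for key, _ in groupby(x for x in path if x != 0)]
print(wrong)  # [3, 5]
```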

1.2 Code

import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")

def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist

def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp

def ctc_greedy_search_decode(probs, blank=0):
    # Most probable index at each time step; tolist() yields plain Python ints
    index_list = np.argmax(probs, axis=1).tolist()

    # Merge adjacent repeats
    index_list = [key for key, group in groupby(index_list)]

    # Remove blanks
    index_list = [x for x in index_list if x != blank]

    return index_list

if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20

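    # Assume 20 time steps and a vocabulary of 20 symbols (index 0 is the blank)
    probs = np.random.rand(time_step, vocab_size)

    # The CTC decoders expect probabilities, so apply softmax to the raw scores first
    probs = softmax(probs)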

    res = ctc_greedy_search_decode(probs)

    print('label:{}'.format(res))

Output

label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12]

2 Beam Search Decode

2.1 Principle

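Greedy Search Decode returns only the single best path and cannot provide any alternative candidates. Beam Search Decode keeps the beam_size highest-scoring paths at every time step, so it returns beam_size candidate paths. For the details of beam search, see my article 深度学习 – 通俗理解Beam Search Algorithm算法 ("Deep Learning – An Intuitive Understanding of the Beam Search Algorithm").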
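Under CTC, each frame is scored independently, so a path's log-probability is just the sum of its per-frame log-probabilities. As a result, a path-level beam search never prunes a prefix of one of the k best complete paths: with beam_size = k it returns exactly the k most probable paths, ignoring exact ties. The sketch below checks this against brute-force enumeration of all V**T paths on a tiny random matrix. The function names are hypothetical, and brute force is only feasible for very small T and V.

```python
import itertools

import numpy as np


def beam_topk_paths(log_probs, beam_size):
    """Path-level beam search (same idea as ctc_beam_search_decode, without collapsing)."""
    T, V = log_probs.shape
    beam = [((), 0.0)]
    for t in range(T):
        # Extend every kept path with every symbol, then keep the best beam_size
        candidates = [(path + (s,), score + log_probs[t, s])
                      for path, score in beam for s in range(V)]
        candidates.sort(key=lambda x: x[1], reverse=True)
        beam = candidates[:beam_size]
    return beam


def brute_force_topk_paths(log_probs, k):
    """Score every one of the V**T paths and keep the k best."""
    T, V = log_probs.shape
    scored = [(path, sum(log_probs[t, s] for t, s in enumerate(path)))
              for path in itertools.product(range(V), repeat=T)]
    scored.sort(key=lambda x: x[1], reverse=True)
    return scored[:k]


rng = np.random.default_rng(0)
probs = rng.random((4, 3))
probs /= probs.sum(axis=1, keepdims=True)  # normalize each frame to a distribution
log_probs = np.log(probs)

beam_paths = [path for path, _ in beam_topk_paths(log_probs, beam_size=3)]
exact_paths = [path for path, _ in brute_force_topk_paths(log_probs, k=3)]
print(beam_paths == exact_paths)  # True
```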

2.2 Code

import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")

def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist

def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp


def ctc_beam_search_decode(probs, beam_size=10, blank=0):
    T, V = probs.shape
    log_probs = np.log(probs)

    # Each beam entry is (path, log-probability of that path)
    beam = [([], 0.0)]
    for t in range(T):
        new_beam = []
        for prefix, score in beam:
            # Extend every kept path with every symbol at time step t
            for i in range(V):
                new_prefix = prefix + [i]
                new_score = score + log_probs[t, i]

                new_beam.append((new_prefix, new_score))

        # Keep the beam_size highest-scoring paths
        new_beam.sort(key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_size]

    # Collapse each path: merge adjacent repeats, then remove blanks
    res = []
    for label, score in beam:
        # Merge adjacent repeats
        index_list = [key for key, group in groupby(label)]

        # Remove blanks
        index_list = [x for x in index_list if x != blank]

        res.append((index_list, score))

    return res

if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20

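    # Assume 20 time steps and a vocabulary of 20 symbols (index 0 is the blank)
    probs = np.random.rand(time_step, vocab_size)

    # The CTC decoders expect probabilities, so apply softmax to the raw scores first
    probs = softmax(probs)

    res = ctc_beam_search_decode(probs, beam_size=3, blank=0)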
    for label, score in res:
        print('label:{},score={}'.format(label,score))

Output

label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.8869170531208
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 6, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.88710645664252
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 12],score=-51.89227568133298

3 Prefix Beam Search Decode

3.1 Principle

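A problem with the Beam Search Decode above is that several of the beam_size paths it finds may collapse to the same label sequence once adjacent repeats and blanks are removed, which greatly reduces the diversity of the results. In addition, each candidate's score is the probability of a single alignment path, whereas under CTC the probability of a label sequence is the sum over all paths that collapse to it. Prefix Beam Search Decode instead searches over collapsed prefixes. Emitting a blank leaves the current prefix unchanged, and paths that produce the same prefix are merged by summing their probabilities. This largely eliminates duplicate results, and it is generally the preferred choice among the three decoders discussed here. For the details, see the paper First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs.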
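Both issues show up on a deliberately tiny example of our own, not taken from the paper. It has two frames and a vocabulary {blank = 0, a = 1}, with P(blank) = 0.6 and P(a) = 0.4 in each frame:

* The most probable single path is blank-blank (0.36), so greedy decoding outputs the empty label.
* The three most probable paths are blank-blank, blank-a and a-blank, and two of them collapse to the same label "a".
* Summing over all paths gives P("a") = 0.24 + 0.24 + 0.16 = 0.64, which is larger than 0.36.

The brute-force sketch below reproduces these numbers. The helper names are hypothetical, and enumerating all V**T paths is only practical for toy inputs.

```python
import itertools
from collections import defaultdict
from itertools import groupby

import numpy as np


def collapse(path, blank=0):
    """Hypothetical helper: merge adjacent repeats, then drop blanks."""
    return tuple(key for key, _ in groupby(path) if key != blank)


def label_probabilities(probs, blank=0):
    """Brute force: sum the probability of every path that collapses to each label."""
    T, V = probs.shape
    totals = defaultdict(float)
    for path in itertools.product(range(V), repeat=T):
        p = 1.0
        for t, s in enumerate(path):
            p *= float(probs[t, s])
        totals[collapse(path, blank)] += p
    return dict(totals)


# Two frames; index 0 is the blank and index 1 is the label "a"
probs = np.array([[0.6, 0.4],
                  [0.6, 0.4]])

# Greedy decoding: the best single path is (0, 0), which collapses to the empty label
print("greedy:", collapse(np.argmax(probs, axis=1).tolist()))  # greedy: ()

# The three most probable single paths, i.e. what a path-level beam of size 3 keeps
paths = sorted(itertools.product(range(2), repeat=2),
               key=lambda path: -float(probs[0, path[0]] * probs[1, path[1]]))
print([collapse(path) for path in paths[:3]])  # [(), (1,), (1,)]

# Label probabilities summed over all paths: "a" wins
for label, p in sorted(label_probabilities(probs).items(), key=lambda x: -x[1]):
    print(label, round(p, 2))
# (1,) 0.64
# () 0.36
```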

The paper gives pseudocode for the algorithm:

[Figure: prefix beam search pseudocode from the paper]
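The update rules implemented by ctc_prefix_beam_search_decode in the code below can be summarized as follows, using our own notation rather than the paper's. Here p_b^t(ℓ) and p_nb^t(ℓ) are the probabilities that the first t frames collapse to prefix ℓ and end in a blank or a non-blank respectively, y_t(k) is the probability of symbol k at frame t, e is the last symbol of ℓ, and ℓ + c is ℓ extended by symbol c. Each "+=" accumulates, because several prefixes can reach the same target. In the code, every product becomes a sum of log-probabilities and every "+=" becomes a logsumexp.

```latex
\begin{aligned}
p_b^{0}(\varnothing) &= 1, \qquad p_{nb}^{0}(\varnothing) = 0 \\
p_b^{t}(\ell) &\mathrel{+}= y_t(\text{blank})\,\bigl(p_b^{t-1}(\ell) + p_{nb}^{t-1}(\ell)\bigr) \\
p_{nb}^{t}(\ell + c) &\mathrel{+}= y_t(c)\,\bigl(p_b^{t-1}(\ell) + p_{nb}^{t-1}(\ell)\bigr), && c \neq e \\
p_{nb}^{t}(\ell + c) &\mathrel{+}= y_t(c)\,p_b^{t-1}(\ell), && c = e \\
p_{nb}^{t}(\ell) &\mathrel{+}= y_t(c)\,p_{nb}^{t-1}(\ell), && c = e \\
P(\ell) &= p_b^{T}(\ell) + p_{nb}^{T}(\ell)
\end{aligned}
```

After each time step, only the beam_size prefixes with the largest p_b + p_nb are kept.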

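The code below relies on a trick that many articles do not explain clearly: the logsumexp function. The probability of a path is the product of its per-frame probabilities. After many time steps it becomes extremely small and, in ordinary floating point, eventually underflows to 0. Adding two such probabilities x and y then loses all information. We therefore keep every probability in log space. Products are easy there, because

\log(x \cdot y) = \log x + \log y

But how do we compute log(x+y) in log space, that is, from log x and log y alone? We can write

\log(x+y) = \log\bigl(\exp(\log x) + \exp(\log y)\bigr)

The expression log(exp(log x) + exp(log y)) has a dedicated name, log-sum-exp:

\mathrm{log\text{-}sum\text{-}exp}(u,v) = \log\bigl(\exp(u) + \exp(v)\bigr)

So log(x+y) can be written as

\log(x+y) = \mathrm{log\text{-}sum\text{-}exp}(\log x, \log y)

Evaluating exp(u) and exp(v) directly would underflow again, so log-sum-exp is computed after factoring out the larger argument:

\mathrm{log\text{-}sum\text{-}exp}(u,v) = \max(u,v) + \log\bigl(\exp(u-\max(u,v)) + \exp(v-\max(u,v))\bigr)

The larger term becomes exp(0) = 1, so the sum inside the logarithm is at least 1 and can neither underflow to 0 nor overflow. This is why logsumexp in the code below is written the way it is. It applies the same formula to any number of arguments, and it returns -inf directly when every argument is -inf (all probabilities are 0), because -inf - (-inf) would be NaN.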
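The following quick check assumes IEEE double precision. Take a path of 200 frames that each have probability 0.01. Its probability is 1e-400, which is below the smallest positive double and underflows to 0.0, while its log-probability (about -921) is perfectly representable. The sketch below repeats the article's logsumexp formula and compares it with the naive one. The values 200 and 0.01 are arbitrary.

```python
import math

NEG_INF = -float("inf")


def logsumexp(*args):
    """Stable log-sum-exp, same formula as the article's logsumexp."""
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


# Probability of one path of 200 frames, each with probability 0.01
x = math.prod([0.01] * 200)
print(x)  # 0.0  (1e-400 underflows)
log_x = 200 * math.log(0.01)  # about -921.03, fine in log space

# Naive log(exp(u) + exp(v)): exp(-921.03) underflows to 0.0 and math.log(0.0) raises
try:
    math.log(math.exp(log_x) + math.exp(log_x))
except ValueError as err:
    print("naive formula fails:", type(err).__name__)  # naive formula fails: ValueError

# Stable version: log(x + x) = log(x) + log(2)
print(round(logsumexp(log_x, log_x) - log_x, 4))  # 0.6931
print(logsumexp(NEG_INF, NEG_INF))  # -inf (all probabilities are 0)
```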

A more detailed explanation can be found in the following article:

3.2 Code

import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")

def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist

def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp

def ctc_prefix_beam_search_decode(prob, beam_size=10, blank=0):
    T, V = prob.shape
    log_prob = np.log(prob)

    # Each entry: (prefix, (log P(prefix, ends in blank), log P(prefix, ends in non-blank)))
    beam = [(tuple(), (0.0, NEG_INF))]  # the empty prefix ends in blank with probability 1
    for t in range(T):  # for every time step
        # Unseen prefixes start with probability 0 (log 0 = -inf) in both cases
        new_beam = defaultdict(lambda: (NEG_INF, NEG_INF))

        for prefix, (p_b, p_nb) in beam:
            for i in range(V):  # for every symbol
                p = log_prob[t, i]

                if i == blank:  # emit a blank: the prefix is unchanged and now ends in blank
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                else:  # emit a non-blank symbol
                    end_t = prefix[-1] if prefix else None

                    # Extend the current prefix with symbol i
                    new_prefix = prefix + (i,)
                    new_p_b, new_p_nb = new_beam[new_prefix]
                    if i != end_t:
                        new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                    else:
                        # A repeated symbol starts a new label only if a blank separates them
                        new_p_nb = logsumexp(new_p_nb, p_b + p)
                    new_beam[new_prefix] = (new_p_b, new_p_nb)

                    # Repeating the last symbol without a blank in between collapses
                    # into the current prefix, which keeps its length
                    if i == end_t:
                        new_p_b, new_p_nb = new_beam[prefix]
                        new_p_nb = logsumexp(new_p_nb, p_nb + p)
                        new_beam[prefix] = (new_p_b, new_p_nb)

        # Keep the beam_size prefixes with the highest total probability
        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]

    return beam

if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20

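    # Assume 20 time steps and a vocabulary of 20 symbols (index 0 is the blank)
    probs = np.random.rand(time_step, vocab_size)

    # The CTC decoders expect probabilities, so apply softmax to the raw scores first
    probs = softmax(probs)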

    res = ctc_prefix_beam_search_decode(probs,beam_size = 3,blank=0)
    for beam in res:
        print('label:{},score={}'.format(beam[0],logsumexp(*beam[1])))

Output

label:(12, 7, 9, 19, 2, 15, 12, 11, 3),score=-43.130412256239644
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 12),score=-43.59912015650705
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 11),score=-43.61975284105764

4 Results of the three CTC decoding algorithms on the same probability matrix

import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")

def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist

def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp

def ctc_greedy_search_decode(probs, blank=0):
    # Most probable index at each time step; tolist() yields plain Python ints
    index_list = np.argmax(probs, axis=1).tolist()

    # Merge adjacent repeats
    index_list = [key for key, group in groupby(index_list)]

    # Remove blanks
    index_list = [x for x in index_list if x != blank]

    return index_list


def ctc_beam_search_decode(probs, beam_size=10, blank=0):
    T, V = probs.shape
    log_probs = np.log(probs)

    # Each beam entry is (path, log-probability of that path)
    beam = [([], 0.0)]
    for t in range(T):
        new_beam = []
        for prefix, score in beam:
            # Extend every kept path with every symbol at time step t
            for i in range(V):
                new_prefix = prefix + [i]
                new_score = score + log_probs[t, i]

                new_beam.append((new_prefix, new_score))

        # Keep the beam_size highest-scoring paths
        new_beam.sort(key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_size]

    # Collapse each path: merge adjacent repeats, then remove blanks
    res = []
    for label, score in beam:
        # Merge adjacent repeats
        index_list = [key for key, group in groupby(label)]

        # Remove blanks
        index_list = [x for x in index_list if x != blank]

        res.append((index_list, score))

    return res

def ctc_prefix_beam_search_decode(prob, beam_size=10, blank=0):
    T, V = prob.shape
    log_prob = np.log(prob)

    # Each entry: (prefix, (log P(prefix, ends in blank), log P(prefix, ends in non-blank)))
    beam = [(tuple(), (0.0, NEG_INF))]  # the empty prefix ends in blank with probability 1
    for t in range(T):  # for every time step
        # Unseen prefixes start with probability 0 (log 0 = -inf) in both cases
        new_beam = defaultdict(lambda: (NEG_INF, NEG_INF))

        for prefix, (p_b, p_nb) in beam:
            for i in range(V):  # for every symbol
                p = log_prob[t, i]

                if i == blank:  # emit a blank: the prefix is unchanged and now ends in blank
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                else:  # emit a non-blank symbol
                    end_t = prefix[-1] if prefix else None

                    # Extend the current prefix with symbol i
                    new_prefix = prefix + (i,)
                    new_p_b, new_p_nb = new_beam[new_prefix]
                    if i != end_t:
                        new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                    else:
                        # A repeated symbol starts a new label only if a blank separates them
                        new_p_nb = logsumexp(new_p_nb, p_b + p)
                    new_beam[new_prefix] = (new_p_b, new_p_nb)

                    # Repeating the last symbol without a blank in between collapses
                    # into the current prefix, which keeps its length
                    if i == end_t:
                        new_p_b, new_p_nb = new_beam[prefix]
                        new_p_nb = logsumexp(new_p_nb, p_nb + p)
                        new_beam[prefix] = (new_p_b, new_p_nb)

        # Keep the beam_size prefixes with the highest total probability
        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]

    return beam

if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20

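    # Assume 20 time steps and a vocabulary of 20 symbols (index 0 is the blank)
    probs = np.random.rand(time_step, vocab_size)

    # The CTC decoders expect probabilities, so apply softmax to the raw scores first
    probs = softmax(probs)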

    res = ctc_greedy_search_decode(probs)

    print('ctc_greedy_search_decode:\nlabel:{}\n'.format(res))

    print('ctc_beam_search_decode:')
    res = ctc_beam_search_decode(probs,beam_size = 3,blank=0)
    for label, score in res:
        print('label:{},score={}'.format(label,score))
    print('\n')


    print('ctc_prefix_beam_search_decode')
    res = ctc_prefix_beam_search_decode(probs,beam_size = 3,blank=0)
    for beam in res:
        print('label:{},score={}'.format(beam[0],logsumexp(*beam[1])))
    print('\n')

Output

ctc_greedy_search_decode:
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12]

ctc_beam_search_decode:
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.8869170531208
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 6, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.88710645664252
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 12],score=-51.89227568133298


ctc_prefix_beam_search_decode
label:(12, 7, 9, 19, 2, 15, 12, 11, 3),score=-43.130412256239644
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 12),score=-43.59912015650705
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 11),score=-43.61975284105764
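In the output above, the prefix beam search scores (about -43.1) are noticeably higher than the beam search scores (about -51.9). This is expected. Beam search scores a single alignment path. Each prefix beam score instead accumulates, in log space, the probabilities of all alignment paths in the beam that collapse to that prefix. With pruning, it is therefore a lower bound on the label's total probability.

The sketch below checks this on a tiny random matrix. prefix_beam_scores is a hypothetical, condensed copy of ctc_prefix_beam_search_decode that returns a dict of total log-probabilities. With a beam large enough that nothing is pruned, its scores should match a brute-force sum over all V**T paths.

```python
import itertools
import math
from collections import defaultdict
from itertools import groupby

import numpy as np

NEG_INF = -float("inf")


def logsumexp(*args):
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


def prefix_beam_scores(prob, beam_size, blank=0):
    """Condensed copy of ctc_prefix_beam_search_decode; returns {prefix: total log-prob}."""
    T, V = prob.shape
    log_prob = np.log(prob)
    beam = [((), (0.0, NEG_INF))]
    for t in range(T):
        new_beam = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_b, p_nb) in beam:
            for i in range(V):
                p = log_prob[t, i]
                if i == blank:
                    nb, nnb = new_beam[prefix]
                    new_beam[prefix] = (logsumexp(nb, p_b + p, p_nb + p), nnb)
                    continue
                end_t = prefix[-1] if prefix else None
                new_prefix = prefix + (i,)
                nb, nnb = new_beam[new_prefix]
                # A repeated symbol needs a blank in between to extend the prefix
                extra = (p_b + p,) if i == end_t else (p_b + p, p_nb + p)
                new_beam[new_prefix] = (nb, logsumexp(nnb, *extra))
                if i == end_t:
                    nb, nnb = new_beam[prefix]
                    new_beam[prefix] = (nb, logsumexp(nnb, p_nb + p))
        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]),
                      reverse=True)[:beam_size]
    return {prefix: logsumexp(*scores) for prefix, scores in beam}


def brute_force_label_log_probs(prob, blank=0):
    """Sum the probability of all V**T paths per collapsed label (toy sizes only)."""
    T, V = prob.shape
    totals = defaultdict(float)
    for path in itertools.product(range(V), repeat=T):
        label = tuple(k for k, _ in groupby(path) if k != blank)
        totals[label] += math.prod(float(prob[t, s]) for t, s in enumerate(path))
    return {label: math.log(p) for label, p in totals.items()}


rng = np.random.default_rng(1)
prob = rng.random((3, 3))
prob /= prob.sum(axis=1, keepdims=True)  # normalize each frame

# With a beam large enough that nothing is pruned, both agree for every label
approx = prefix_beam_scores(prob, beam_size=100)
exact = brute_force_label_log_probs(prob)
print(all(math.isclose(approx[label], exact[label]) for label in exact))  # True
```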

References