Deep Learning – Implementing the CTC Decoding Algorithms Greedy Search Decode, Beam Search Decode and Prefix Beam Search Decode in Python
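In speech recognition and OCR (optical character recognition), the last step of inference is to run a CTC decoding algorithm over the predicted probability matrix to find the most probable output sequence. Common CTC decoding algorithms include Greedy Search Decode, Beam Search Decode and Prefix Beam Search Decode, of which Greedy Search Decode and Prefix Beam Search Decode are used most often. This article implements all three in Python, one at a time.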
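1 Greedy Search Decode
1.1 Principle
Greedy Search Decode picks the character with the highest probability at each time step, then merges adjacent repeated characters, and finally removes the blank characters.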
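The order of the last two steps matters: merging repeats first lets a blank separate two genuine copies of the same character, while dropping blanks first would fuse them. The sketch below shows this on a hand-made path; the path and the helper names `ctc_collapse` and `wrong_order` are hypothetical, for illustration only, and index 0 is assumed to be the blank.

```python
from itertools import groupby


def ctc_collapse(path, blank=0):
    # CTC order: merge runs of identical indices first, then drop blanks
    merged = [k for k, _ in groupby(path)]
    return [k for k in merged if k != blank]


def wrong_order(path, blank=0):
    # reversed order, shown only for contrast: drop blanks first, then merge
    no_blank = [k for k in path if k != blank]
    return [k for k, _ in groupby(no_blank)]


# hypothetical per-frame argmax path; 0 is the blank
path = [1, 1, 0, 1, 2, 2, 0, 0, 3]
print(ctc_collapse(path))  # [1, 1, 2, 3]
print(wrong_order(path))   # [1, 2, 3]
```

With the CTC order, the blank between the two runs of 1 keeps them as two separate characters.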
1.2 Code
```python
import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")


def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist


def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp


def ctc_greedy_search_decode(probs, blank=0):
    # most probable index at every time step
    index_list = np.argmax(probs, axis=1)
    # merge adjacent repeats (int() keeps plain Python ints in the result)
    index_list = [int(key) for key, group in groupby(index_list)]
    # remove blanks
    index_list = [*filter(lambda x: x != blank, index_list)]
    return index_list


if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20
    # assume 20 time steps and a vocabulary of 20 tokens
    probs = np.random.rand(time_step, vocab_size)
    # the CTC decoder expects probabilities, so apply softmax first
    probs = softmax(probs)
    res = ctc_greedy_search_decode(probs)
    print('label:{}'.format(res))
```
Output
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12]
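2 Beam Search Decode
2.1 Principle
Greedy Search Decode can only find the single best path (the most probable frame-by-frame alignment); it cannot return any of the other good candidate paths. Beam Search Decode instead keeps beam size candidate paths at every step, where beam size is a setting you choose. For the details of how beam search works, see my article 深度学习 – 通俗理解Beam Search Algorithm算法 (Deep Learning – An Intuitive Explanation of the Beam Search Algorithm).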
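A tiny hand-worked example (made-up numbers, not from the article) shows that the best path is not always the best label sequence. Assume T = 2 frames, a vocabulary of {ε, a} where ε is the blank, and the same distribution at both frames, P(ε) = 0.6 and P(a) = 0.4. Greedy picks ε at both frames, yet the label "a" is more probable once all paths that collapse to it are added up:

$$
\begin{aligned}
P(\epsilon\,\epsilon) &= 0.6 \times 0.6 = 0.36 \quad \Rightarrow \text{empty label (greedy result)}\\
P(\text{label } a) &= P(a\,a) + P(a\,\epsilon) + P(\epsilon\,a)\\
&= 0.4 \times 0.4 + 0.4 \times 0.6 + 0.6 \times 0.4 = 0.16 + 0.24 + 0.24 = 0.64
\end{aligned}
$$

Keeping several candidate paths helps, but the beam search in this section still scores individual paths; summing over all paths that share a label is what Prefix Beam Search Decode adds.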
2.2 Code
```python
import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")


def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist


def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp


def ctc_beam_search_decode(probs, beam_size=10, blank=0):
    T, V = probs.shape
    log_probs = np.log(probs)
    # each beam entry is (path, accumulated log-probability of that path)
    beam = [([], 0)]
    for t in range(T):
        new_beam = []
        for prefix, score in beam:
            for i in range(V):
                new_prefix = prefix + [i]
                new_score = score + log_probs[t, i]
                new_beam.append((new_prefix, new_score))
        # keep the top beam_size paths
        new_beam.sort(key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_size]
    # merge adjacent repeats and remove blanks
    res = []
    for label, score in beam:
        # merge adjacent repeats
        index_list = [key for key, group in groupby(label)]
        # remove blanks
        index_list = [*filter(lambda x: x != blank, index_list)]
        res.append((index_list, score))
    return res


if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20
    # assume 20 time steps and a vocabulary of 20 tokens
    probs = np.random.rand(time_step, vocab_size)
    # the CTC decoder expects probabilities, so apply softmax first
    probs = softmax(probs)
    res = ctc_beam_search_decode(probs, beam_size=3, blank=0)
    for label, score in res:
        print('label:{},score={}'.format(label, score))
```
Output
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.8869170531208
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 6, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.88710645664252
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 12],score=-51.89227568133298
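3 Prefix Beam Search Decode
3.1 Principle
One problem with the Beam Search Decode above is that several of the beam size paths it returns may turn into the same label sequence once adjacent repeats and blanks are removed, which considerably reduces the diversity of the results. Prefix Beam Search Decode works on prefixes (collapsed label sequences) rather than raw paths: emitting a blank leaves the current prefix unchanged, and the probabilities of all paths that lead to the same prefix are merged into a single entry, tracked separately as "ends in blank" and "ends in non-blank". This largely eliminates duplicate paths, which is why it is the CTC decoding method most often preferred in practice. For the details, see the paper First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs.
The paper gives pseudocode for the algorithm.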
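The duplicate problem, and the merging that removes it, can be checked by brute force on a tiny input. The sketch below assumes a hypothetical two-frame distribution over {0: blank, 1, 2} and enumerates all V^T paths, so it is only feasible for very small T and V. It first takes the top-3 complete paths (the kind of result the path-level beam search above keeps with beam_size=3), then sums the probabilities of all paths that collapse to the same label, which is the quantity Prefix Beam Search Decode approximates with a finite beam.

```python
import itertools

import numpy as np


def collapse(path, blank=0):
    # CTC collapse: merge adjacent repeats, then drop blanks
    merged = [k for k, _ in itertools.groupby(path)]
    return tuple(k for k in merged if k != blank)


# hypothetical 2-frame distribution over {0: blank, 1, 2}
probs = np.array([[0.5, 0.3, 0.2],
                  [0.5, 0.3, 0.2]])
T, V = probs.shape

# score every complete path (only feasible for tiny T and V)
paths = []
for path in itertools.product(range(V), repeat=T):
    p = 1.0
    for t, s in enumerate(path):
        p *= probs[t, s]
    paths.append((path, p))

# path-level view: the top-3 paths, then their collapsed labels
top_paths = sorted(paths, key=lambda x: x[1], reverse=True)[:3]
top_labels = [collapse(path) for path, _ in top_paths]
print("labels of top-3 paths:", top_labels)  # [(), (1,), (1,)]

# prefix-level view: merge all paths that collapse to the same label
merged = {}
for path, p in paths:
    key = collapse(path)
    merged[key] = merged.get(key, 0.0) + p
best_label = max(merged, key=merged.get)
print("best label after merging:", best_label, round(merged[best_label], 4))
```

Two of the three kept paths describe the same label (1,), and the empty label, which the single best path points to, loses once the paths are merged: (1,) collects 0.15 + 0.15 + 0.09 ≈ 0.39 against 0.25.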
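The example code below relies on a trick that many articles do not explain clearly: the logsumexp function. When two probabilities x and y that must be added (for example, the probabilities of two paths that lead to the same prefix) are both very small, computing the sum x + y directly can underflow, so we move the computation into log space. We know that

$$x = e^{\log x}, \qquad y = e^{\log y}$$

How can log(x + y) be computed in log space? We can write

$$\log(x + y) = \log\left(e^{\log x} + e^{\log y}\right)$$

The expression log(exp(log x) + exp(log y)) has a dedicated name, log-sum-exp:

$$\operatorname{logsumexp}(u, v) = \log\left(e^{u} + e^{v}\right)$$

so log(x + y) can be written as

$$\log(x + y) = \operatorname{logsumexp}(\log x, \log y)$$

and logsumexp(u, v) can be computed stably with

$$\operatorname{logsumexp}(u, v) = m + \log\left(e^{u - m} + e^{v - m}\right), \qquad m = \max(u, v)$$

This is why logsumexp in the code below is written the way it is: after subtracting the maximum, every exponent is at most 0 and at least one term equals 1, so the sum can neither overflow nor collapse to 0. The case where every argument is -inf (probability 0) is handled separately, because -inf - (-inf) would be NaN.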
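A quick numerical check of the underflow argument, as a sketch: the two log-probabilities below are made-up values, and `logsumexp` is the same stable formula as in the article.

```python
import math

NEG_INF = -float("inf")


def logsumexp(*args):
    # stable log(sum(exp(args))): shift by the max before exponentiating
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


# made-up log-probabilities of two very unlikely paths
log_x, log_y = -800.0, -801.0

# naive route: exp() underflows to 0.0, so the sum is 0 and its log is undefined
naive_sum = math.exp(log_x) + math.exp(log_y)
print(naive_sum)  # 0.0
try:
    math.log(naive_sum)
except ValueError as err:
    print("naive log fails:", err)

# stable route: -800 + log(1 + e^-1)
stable = logsumexp(log_x, log_y)
print(stable)  # ≈ -799.687
```

For ordinary probabilities the two routes agree, e.g. `logsumexp(math.log(0.2), math.log(0.3))` equals `math.log(0.5)` up to rounding.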
For a more detailed explanation, see the following articles:
3.2 Code
```python
import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")


def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist


def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp


def ctc_prefix_beam_search_decode(prob, beam_size=10, blank=0):
    T, V = prob.shape
    log_prob = np.log(prob)
    # entry: (prefix, (log P(prefix ending in blank), log P(prefix ending in non-blank)))
    beam = [(tuple(), (0, NEG_INF))]  # blank, non-blank
    for t in range(T):  # for every timestep
        new_beam = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_b, p_nb) in beam:
            for i in range(V):  # for every state
                p = log_prob[t, i]
                if i == blank:  # propose a blank: the prefix does not change
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                    continue
                else:  # extend with non-blank
                    end_t = prefix[-1] if prefix else None
                    # extend current prefix
                    new_prefix = prefix + (i,)
                    new_p_b, new_p_nb = new_beam[new_prefix]
                    if i != end_t:
                        new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                    else:
                        # a repeated character starts a new one only after a blank
                        new_p_nb = logsumexp(new_p_nb, p_b + p)
                    new_beam[new_prefix] = (new_p_b, new_p_nb)
                    # keep current prefix
                    if i == end_t:
                        # repeating the last character without a blank merges into it
                        new_p_b, new_p_nb = new_beam[prefix]
                        new_p_nb = logsumexp(new_p_nb, p_nb + p)
                        new_beam[prefix] = (new_p_b, new_p_nb)
        # keep the top beam_size prefixes by total probability
        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]
    return beam


if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20
    # assume 20 time steps and a vocabulary of 20 tokens
    probs = np.random.rand(time_step, vocab_size)
    # the CTC decoder expects probabilities, so apply softmax first
    probs = softmax(probs)
    res = ctc_prefix_beam_search_decode(probs, beam_size=3, blank=0)
    for beam in res:
        print('label:{},score={}'.format(beam[0], logsumexp(*beam[1])))
```
Output
label:(12, 7, 9, 19, 2, 15, 12, 11, 3),score=-43.130412256239644
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 12),score=-43.59912015650705
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 11),score=-43.61975284105764
4 Results of the three CTC decoding algorithms on the same probability matrix
```python
import numpy as np
from itertools import groupby
import math
from collections import defaultdict

NEG_INF = -float("inf")


def softmax(logits):
    max_value = np.max(logits, axis=1, keepdims=True)
    exp = np.exp(logits - max_value)
    exp_sum = np.sum(exp, axis=1, keepdims=True)
    dist = exp / exp_sum
    return dist


def logsumexp(*args):
    """
    Stable log sum exp.
    """
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max)
                       for a in args))
    return a_max + lsp


def ctc_greedy_search_decode(probs, blank=0):
    index_list = np.argmax(probs, axis=1)
    # merge adjacent repeats (int() keeps plain Python ints in the result)
    index_list = [int(key) for key, group in groupby(index_list)]
    # remove blanks
    index_list = [*filter(lambda x: x != blank, index_list)]
    return index_list


def ctc_beam_search_decode(probs, beam_size=10, blank=0):
    T, V = probs.shape
    log_probs = np.log(probs)
    beam = [([], 0)]
    for t in range(T):
        new_beam = []
        for prefix, score in beam:
            for i in range(V):
                new_prefix = prefix + [i]
                new_score = score + log_probs[t, i]
                new_beam.append((new_prefix, new_score))
        # keep the top beam_size paths
        new_beam.sort(key=lambda x: x[1], reverse=True)
        beam = new_beam[:beam_size]
    # merge adjacent repeats and remove blanks
    res = []
    for label, score in beam:
        # merge adjacent repeats
        index_list = [key for key, group in groupby(label)]
        # remove blanks
        index_list = [*filter(lambda x: x != blank, index_list)]
        res.append((index_list, score))
    return res


def ctc_prefix_beam_search_decode(prob, beam_size=10, blank=0):
    T, V = prob.shape
    log_prob = np.log(prob)
    beam = [(tuple(), (0, NEG_INF))]  # blank, non-blank
    for t in range(T):  # for every timestep
        new_beam = defaultdict(lambda: (NEG_INF, NEG_INF))
        for prefix, (p_b, p_nb) in beam:
            for i in range(V):  # for every state
                p = log_prob[t, i]
                if i == blank:  # propose a blank
                    new_p_b, new_p_nb = new_beam[prefix]
                    new_p_b = logsumexp(new_p_b, p_b + p, p_nb + p)
                    new_beam[prefix] = (new_p_b, new_p_nb)
                    continue
                else:  # extend with non-blank
                    end_t = prefix[-1] if prefix else None
                    # extend current prefix
                    new_prefix = prefix + (i,)
                    new_p_b, new_p_nb = new_beam[new_prefix]
                    if i != end_t:
                        new_p_nb = logsumexp(new_p_nb, p_b + p, p_nb + p)
                    else:
                        new_p_nb = logsumexp(new_p_nb, p_b + p)
                    new_beam[new_prefix] = (new_p_b, new_p_nb)
                    # keep current prefix
                    if i == end_t:
                        new_p_b, new_p_nb = new_beam[prefix]
                        new_p_nb = logsumexp(new_p_nb, p_nb + p)
                        new_beam[prefix] = (new_p_b, new_p_nb)
        # keep the top beam_size prefixes
        beam = sorted(new_beam.items(), key=lambda x: logsumexp(*x[1]), reverse=True)
        beam = beam[:beam_size]
    return beam


if __name__ == '__main__':
    np.random.seed(11)
    time_step = 20
    vocab_size = 20
    # assume 20 time steps and a vocabulary of 20 tokens
    probs = np.random.rand(time_step, vocab_size)
    # the CTC decoders expect probabilities, so apply softmax first
    probs = softmax(probs)
    res = ctc_greedy_search_decode(probs)
    print('ctc_greedy_search_decode:\nlabel:{}\n'.format(res))
    print('ctc_beam_search_decode:')
    res = ctc_beam_search_decode(probs, beam_size=3, blank=0)
    for label, score in res:
        print('label:{},score={}'.format(label, score))
    print('\n')
    print('ctc_prefix_beam_search_decode')
    res = ctc_prefix_beam_search_decode(probs, beam_size=3, blank=0)
    for beam in res:
        print('label:{},score={}'.format(beam[0], logsumexp(*beam[1])))
    print('\n')
```
Output
ctc_greedy_search_decode:
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12]
ctc_beam_search_decode:
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.8869170531208
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 6, 15, 16, 7, 11, 18, 3, 1, 12],score=-51.88710645664252
label:[8, 16, 7, 9, 10, 8, 11, 2, 7, 15, 16, 7, 11, 18, 3, 12],score=-51.89227568133298
ctc_prefix_beam_search_decode
label:(12, 7, 9, 19, 2, 15, 12, 11, 3),score=-43.130412256239644
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 12),score=-43.59912015650705
label:(12, 7, 9, 19, 2, 15, 12, 11, 3, 11),score=-43.61975284105764
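The scores above are not on the same scale. In `ctc_beam_search_decode` a score is the summed log-probability of a single path. In `ctc_prefix_beam_search_decode` the score `logsumexp(p_b, p_nb)` accumulates the probabilities of all alignments of a prefix that survived pruning, which is why those scores (about -43) are much higher than the path scores (about -51.9). One way to compare the labels on equal footing is to score every candidate with the exact CTC label probability from the forward algorithm. The helper `ctc_label_log_prob` below is a hypothetical sketch of that idea, not part of the original article; it assumes `probs` is a T × V matrix of per-frame probabilities with the blank at index 0.

```python
import math

import numpy as np

NEG_INF = -float("inf")


def logsumexp(*args):
    # stable log(sum(exp(args))), same formula as in the article
    if all(a == NEG_INF for a in args):
        return NEG_INF
    a_max = max(args)
    return a_max + math.log(sum(math.exp(a - a_max) for a in args))


def ctc_label_log_prob(probs, label, blank=0):
    """Hypothetical helper: exact log P(label | probs) via the CTC forward algorithm."""
    log_probs = np.log(probs)
    T = log_probs.shape[0]
    # extended label with a blank around every character: [b, l1, b, l2, ..., b]
    ext = [blank]
    for c in label:
        ext += [c, blank]
    S = len(ext)
    # alpha[s]: log-prob of all partial paths that end at position s of ext
    alpha = [NEG_INF] * S
    alpha[0] = log_probs[0, ext[0]]  # start with a blank ...
    if S > 1:
        alpha[1] = log_probs[0, ext[1]]  # ... or with the first character
    for t in range(1, T):
        new_alpha = [NEG_INF] * S
        for s in range(S):
            terms = [alpha[s]]  # stay on the same symbol
            if s >= 1:
                terms.append(alpha[s - 1])  # advance by one position
            # skip the blank in between only when the two characters differ
            if s >= 2 and ext[s] != blank and ext[s] != ext[s - 2]:
                terms.append(alpha[s - 2])
            new_alpha[s] = logsumexp(*terms) + log_probs[t, ext[s]]
        alpha = new_alpha
    # a valid path ends on the last character or on the trailing blank
    return logsumexp(*alpha[-2:]) if S > 1 else alpha[0]


if __name__ == "__main__":
    toy = np.array([[0.6, 0.4], [0.6, 0.4]])
    print(math.exp(ctc_label_log_prob(toy, [1])))  # ≈ 0.64
    print(math.exp(ctc_label_log_prob(toy, [])))   # ≈ 0.36
```

Applying `ctc_label_log_prob(probs, label)` to each label printed by the greedy and beam searches puts them on the same scale as the prefix beam scores. A prefix beam score can only be lower than or equal to this exact value, because alignments that were pruned from the beam are missing from its sum.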
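Author: StubbornHuang
Copyright notice: This is an original article by the site owner. If you repost it, please cite the original link.
Original title: Deep Learning – Implementing the CTC Decoding Algorithms Greedy Search Decode, Beam Search Decode and Prefix Beam Search Decode in Python
Original link: https://www.stubbornhuang.com/2222/
Published: 2022-07-19 08:53:27
Last modified: 2023-06-25 20:53:23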