1 Beam Search Algorithm
在本文中会尽量以通俗易懂的方式介绍Beam Search Algorithm的原理。
在机器翻译领域(Encoder-Decoder模型),将一种语言翻译成另外一种语言时,我们首先需要对源语言的单词序列进行编码,然后通过深度学习模型训练和推理得到中间输出,然后我们需要使用解码器将中间数据解码为目标语言的单词序列,在这一过程中,通常使用Beam Search Algorithm作为解码的核心算法,找出目标语言的单词序列。
1.1 Beam Search Algorithm是什么?
Beam Search Algorithm是一种贪心搜索算法,类似于广度优先搜索(BFS,Breadth First Search)和最佳优先搜索(BeFS,Best-First-Search),我们可以将广度优先搜索BFS和最佳优先搜索BeFS看做是Beam Search的一种特殊形式。
我们一个简单的示例说明下Beam Search算法,假设现在我们有一个图G,给定一个目标节点P_{target},现在我们需要通过遍历图G到达节点P_{target}。首先我们从根节点开始遍历图G,第一步,我们搜索所有可能的节点并从中选取\beta个可能性最大的节点;第二步,基于第一步选取的\beta个节点下再次寻找下一批\beta个可能性最大的节点,直到找到目标节点P_{target}。我们将\beta称为beam width,而我们通过上述方式找到的节点路径就成为我们找到目标点的最佳节点路径。
我们可以注意到,当\beta = 1时,那么这种情况下Beam Search算法就是最佳优先搜索算法BeFS,因为当\beta = 1时,我们总是在每一步中找到最有可能性的节点作为我们的最佳节点,然后进行下一步的搜索;而当\beta = \infty时,这种情况下Beam Search算法就会成为广度优先搜索算法BFS,因为当\beta = \infty时,我们会将所有的点作为我们的最佳点,然后进行下一步的搜索。
1.2 Beam Search Algorithm的伪代码
下图是Beam Search Algorithm的伪代码,
1.3 Beam Search Algorithm搜索的简单示例
在本节中,我们通过一个简单的示例详细描述下Beam Search Algorithm的搜索过程。
参考上图,我们假设beam width\beta = 2,即每次获取两个最佳节点。我们从start节点开始,第1步我们从所有的词中找到了两个最佳词arrived和the;然后以这两个词为基础,继续进行beam width\beta = 2的搜索,那么可选词序列为[arrived the,arrived witch,the green,the witch]
,从这些可能的词序列中我们选择两个可能性最大的词序列[the green,the witch]
;在上一步结果的基础上,我们在第三步中又选择最大化当前概率的词序列[the green witch,the width who]
;重复以上的步骤,直到整个序列都搜索完毕,这个时候我们会得到在beam width\beta = 2时的可能性最大的翻译序列。
beam width的宽度对最后的翻译结果以及搜索效率有着很大的影响,当beam width过小时,通常会忽略了自然语言中的长期依赖性,从而获取准确性不大的翻译序列;当beam width过大时,会造成需要搜索的路径数量过大,需要消耗更多的内存以及更多的算力。
1.4 Beam Search Algorithm的优缺点
与最佳优先搜索相比,Beam Search Algorithm的一个优点是需要的内存更少。因为Beam Search Algorithm不需要将连续节点存储在队列中。另一方面,它仍然存在BeFS的一些问题。第一个是它的搜索是不完整,这导致可能根本找不到正确的解决方案。第二个问题是它也不是最优的。因此,当它返回一个解决方案时,它可能不是最佳解决方案。
1.5 Beam Search Algorithm中的Beam Size/Beam Width的作用
Beam Size/Beam Width可以看做是Beam Search Algorithm的超参数,这个超参数的设置一般取决于我们希望Beam Search Algorithm的计算速度有多快、内存消耗量以及得到何等准确性的结果。
更大的Beam Size/Beam Width可以获取更加准确的搜索结果,但是也需要更长的计算时间与消耗更多的内存,所以在设置Beam Size/Beam Width时,一般需要在准确性和性能开销上保持平衡。
1.6 代码实现
from math import log
import numpy as np
# beam search
def beam_search_decoder(data, beam_size):
sequences = [[list(), 0.0]]
# 遍历每一个序列
for row in data:
all_candidates = list()
# 在下一个序列中找到候选者
for i in range(len(sequences)):
seq, score = sequences[i]
for j in range(len(row)):
candidate = [seq + [j], score - log(row[j])]
all_candidates.append(candidate)
# 根据分数排序所有的候选者
ordered = sorted(all_candidates, key=lambda tup:tup[1])
# 选择beam_size个最大的
sequences = ordered[:beam_size]
return sequences
if __name__ == '__main__':
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1]]
data = np.array(data)
result = beam_search_decoder(data,3)
for seq in result:
print(seq)
输出
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 6.931471805599453]
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 7.154615356913663]
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 7.154615356913663]
参考链接
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:深度学习 – 通俗理解Beam Search Algorithm算法
原文链接:https://www.stubbornhuang.com/2218/
发布于:2022年07月14日 9:41:26
修改于:2023年06月25日 20:56:37
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50