1 torch.topk
形式
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
或者
Tensor.topk(k, dim=None, largest=True, sorted=True)
功能
返回输入张量input在指定dim上k个最大的元素。
如果没有指定dim,则默认选择最后一个维度。
如果largest为false,则返回最小的元素。
参数
- input(Tensor):输入张量
- k(int):top-k中的k,k个最大的元素
- dim:排序的维度
- largest(bool):如果为false,则返回最小的元素,默认为true
- sorted(bool):如果为true,则返回排序好的k个元素,默认为true
返回值
返回一个(value,indices)
的元素,其中value为dim上最大k个元素的值,indices为dim上最大k个元素对应的索引
使用示例
import torch
if __name__ == '__main__':
a = torch.randn(size=(4, 4))
print(a)
value, indices = torch.topk(a, 2)
print("==========value============")
print(value)
print("==========indices============")
print(indices)
输出
tensor([[ 0.8897, -0.1849, -0.0502, 0.6197],
[-1.6369, 0.9490, -0.8909, -0.5208],
[-1.2246, 1.5075, 0.6353, -0.3103],
[-0.5476, -0.9881, -0.5171, -1.0863]])
==========value============
tensor([[ 0.8897, 0.6197],
[ 0.9490, -0.5208],
[ 1.5075, 0.6353],
[-0.5171, -0.5476]])
==========indices============
tensor([[0, 3],
[1, 3],
[1, 2],
[2, 0]])
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – torch.topk参数详解与使用
原文链接:https://www.stubbornhuang.com/2454/
发布于:2022年12月15日 16:54:37
修改于:2023年06月21日 17:42:54
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50