Pytorch – torch.unsqueeze和torch.squeeze函数
Pytorch中,unsqueeze
和squeeze
为两个对应的反操作函数,其中,unsqueeze
主要用于为输入张量升维,squeeze
主要用于给张量降维,两者的具体用法可以参考下文。
1 unsqueeze
形式
torch.unsqueeze(input, dim)
或者
Tensor.unsqueeze(dim)
功能
在输入张量input的指定索引插入1维度。
dim的数值处于[-input.dim()-1,input.dim()+1]的区间内,如果dim的值为负,则计算为dim=dim+input.dim()+1。
参数
- input:输入的张量
- dim:插入维度的索引
使用示例:
import torch
if __name__ == '__main__':
a = torch.randn(size=(5, 3, 224, 224))
print(a.shape)
b = torch.unsqueeze(a, 1)
print(b.shape)
c = b.unsqueeze(3)
print(c.shape)
输出
torch.Size([5, 3, 224, 224])
torch.Size([5, 1, 3, 224, 224])
torch.Size([5, 1, 3, 1, 224, 224])
2 squeeze
形式
torch.squeeze(input, dim=None)
或者
Tensor.squeeze(dim=None)
功能
删除input输入张量中所有大小为1的维度。
比如说,输入张量的维度为(A \times 1 \times B \times C \times 1 \times D),那么经过squeeze函数处理之后,输出张量的维度为(A \times B \times C \times D)。
如果指定了dim,那么squeeze算子仅仅只会在给定的dim上操作,比如输入张量维度为(A \times 1 \times B),squeeze(input,0)
不会改变输入张量,只有使用squeeze(input,1)
才会将输出张量(A \times B)。
参数
- input:输入张量
- dim:需要操作的维度索引,如果被指定,则只会在该维度上应用squeeze操作
使用示例:
import torch
if __name__ == '__main__':
a = torch.randn(size=(5, 1, 224, 224, 1))
print(a.shape)
b = torch.squeeze(a)
print(b.shape)
c = b.unsqueeze(3)
print(c.shape)
d = c.squeeze(3)
print(d.shape)
输出
torch.Size([5, 1, 224, 224, 1])
torch.Size([5, 224, 224])
torch.Size([5, 224, 224, 1])
torch.Size([5, 224, 224])
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – torch.unsqueeze和torch.squeeze函数
原文链接:https://www.stubbornhuang.com/2440/
发布于:2022年12月08日 15:01:13
修改于:2023年06月21日 17:45:04
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50