1 torch.chunk参数详解与使用
1.1 torch.chunk
1.函数形式
torch.chunk(input, chunks, dim=0) → List of Tensors
2.函数功能
将输入Tensor拆分为特定数量的块。
如果给定维度dim上的Tensor大小不能够被整除,则最后一个块会小于之前的块。
3.函数参数
- input:Tensor类型,需要被拆分的Tensor
- chunks:int类型,需要拆分的chunk数量
- dim:int类型,拆分张量的维度
4.函数返回值
返回拆分好的Tensor列表
1.2 torch.chunk的使用
1.2.1 将torch.chunk应用在二维Tensor上
import torch
if __name__ == '__main__':
input = torch.arange(0,16).view(4,4)
chunk_dim0 = torch.chunk(input,4,0)
chunk_dim1 = torch.chunk(input,4,1)
print('input:{0} \n'.format(input))
print('chunk_dim0:{0} \n'.format(chunk_dim0))
print('chunk_dim1:{0} \n'.format(chunk_dim1))
输出
input:tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]])
chunk_dim0:(tensor([[0, 1, 2, 3]]), tensor([[4, 5, 6, 7]]), tensor([[ 8, 9, 10, 11]]), tensor([[12, 13, 14, 15]]))
chunk_dim1:(tensor([[ 0],
[ 4],
[ 8],
[12]]), tensor([[ 1],
[ 5],
[ 9],
[13]]), tensor([[ 2],
[ 6],
[10],
[14]]), tensor([[ 3],
[ 7],
[11],
[15]]))
1.2.2 将torch.chunk应用在三维Tensor上
import torch
if __name__ == '__main__':
input = torch.arange(0,64).view(4,4,4)
chunk_dim0 = torch.chunk(input,4,0)
chunk_dim1 = torch.chunk(input,4,1)
chunk_dim2 = torch.chunk(input, 4, 2)
print('input:{0} \n'.format(input))
print('chunk_dim0:{0} \n'.format(chunk_dim0))
print('chunk_dim1:{0} \n'.format(chunk_dim1))
print('chunk_dim2:{0} \n'.format(chunk_dim2))
输出
input:tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]],
[[32, 33, 34, 35],
[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]],
[[48, 49, 50, 51],
[52, 53, 54, 55],
[56, 57, 58, 59],
[60, 61, 62, 63]]])
chunk_dim0:(tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]]]), tensor([[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]]]), tensor([[[32, 33, 34, 35],
[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]]]), tensor([[[48, 49, 50, 51],
[52, 53, 54, 55],
[56, 57, 58, 59],
[60, 61, 62, 63]]]))
chunk_dim1:(tensor([[[ 0, 1, 2, 3]],
[[16, 17, 18, 19]],
[[32, 33, 34, 35]],
[[48, 49, 50, 51]]]), tensor([[[ 4, 5, 6, 7]],
[[20, 21, 22, 23]],
[[36, 37, 38, 39]],
[[52, 53, 54, 55]]]), tensor([[[ 8, 9, 10, 11]],
[[24, 25, 26, 27]],
[[40, 41, 42, 43]],
[[56, 57, 58, 59]]]), tensor([[[12, 13, 14, 15]],
[[28, 29, 30, 31]],
[[44, 45, 46, 47]],
[[60, 61, 62, 63]]]))
chunk_dim2:(tensor([[[ 0],
[ 4],
[ 8],
[12]],
[[16],
[20],
[24],
[28]],
[[32],
[36],
[40],
[44]],
[[48],
[52],
[56],
[60]]]), tensor([[[ 1],
[ 5],
[ 9],
[13]],
[[17],
[21],
[25],
[29]],
[[33],
[37],
[41],
[45]],
[[49],
[53],
[57],
[61]]]), tensor([[[ 2],
[ 6],
[10],
[14]],
[[18],
[22],
[26],
[30]],
[[34],
[38],
[42],
[46]],
[[50],
[54],
[58],
[62]]]), tensor([[[ 3],
[ 7],
[11],
[15]],
[[19],
[23],
[27],
[31]],
[[35],
[39],
[43],
[47]],
[[51],
[55],
[59],
[63]]]))
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – torch.chunk参数详解与使用
原文链接:https://www.stubbornhuang.com/2215/
发布于:2022年07月22日 10:31:32
修改于:2023年06月25日 20:51:39
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50