1 torch.chunk参数详解与使用

1.1 torch.chunk

1.函数形式

torch.chunk(input, chunks, dim=0) → List of Tensors

2.函数功能

将输入Tensor拆分为特定数量的块。

如果给定维度dim上的Tensor大小不能够被整除,则最后一个块会小于之前的块。

3.函数参数

  • input:Tensor类型,需要被拆分的Tensor
  • chunks:int类型,需要拆分的chunk数量
  • dim:int类型,拆分张量的维度

4.函数返回值

返回拆分好的Tensor列表

1.2 torch.chunk的使用

1.2.1 将torch.chunk应用在二维Tensor上

import torch

if __name__ == '__main__':
    input = torch.arange(0,16).view(4,4)
    chunk_dim0 = torch.chunk(input,4,0)
    chunk_dim1 = torch.chunk(input,4,1)
    print('input:{0} \n'.format(input))
    print('chunk_dim0:{0} \n'.format(chunk_dim0))
    print('chunk_dim1:{0} \n'.format(chunk_dim1))

输出

input:tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]]) 

chunk_dim0:(tensor([[0, 1, 2, 3]]), tensor([[4, 5, 6, 7]]), tensor([[ 8,  9, 10, 11]]), tensor([[12, 13, 14, 15]])) 

chunk_dim1:(tensor([[ 0],
        [ 4],
        [ 8],
        [12]]), tensor([[ 1],
        [ 5],
        [ 9],
        [13]]), tensor([[ 2],
        [ 6],
        [10],
        [14]]), tensor([[ 3],
        [ 7],
        [11],
        [15]]))

1.2.2 将torch.chunk应用在三维Tensor上

import torch

if __name__ == '__main__':
    input = torch.arange(0,64).view(4,4,4)
    chunk_dim0 = torch.chunk(input,4,0)
    chunk_dim1 = torch.chunk(input,4,1)
    chunk_dim2 = torch.chunk(input, 4, 2)
    print('input:{0} \n'.format(input))
    print('chunk_dim0:{0} \n'.format(chunk_dim0))
    print('chunk_dim1:{0} \n'.format(chunk_dim1))
    print('chunk_dim2:{0} \n'.format(chunk_dim2))

输出

input:tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]],

        [[32, 33, 34, 35],
         [36, 37, 38, 39],
         [40, 41, 42, 43],
         [44, 45, 46, 47]],

        [[48, 49, 50, 51],
         [52, 53, 54, 55],
         [56, 57, 58, 59],
         [60, 61, 62, 63]]]) 

chunk_dim0:(tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]]), tensor([[[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]]]), tensor([[[32, 33, 34, 35],
         [36, 37, 38, 39],
         [40, 41, 42, 43],
         [44, 45, 46, 47]]]), tensor([[[48, 49, 50, 51],
         [52, 53, 54, 55],
         [56, 57, 58, 59],
         [60, 61, 62, 63]]])) 

chunk_dim1:(tensor([[[ 0,  1,  2,  3]],

        [[16, 17, 18, 19]],

        [[32, 33, 34, 35]],

        [[48, 49, 50, 51]]]), tensor([[[ 4,  5,  6,  7]],

        [[20, 21, 22, 23]],

        [[36, 37, 38, 39]],

        [[52, 53, 54, 55]]]), tensor([[[ 8,  9, 10, 11]],

        [[24, 25, 26, 27]],

        [[40, 41, 42, 43]],

        [[56, 57, 58, 59]]]), tensor([[[12, 13, 14, 15]],

        [[28, 29, 30, 31]],

        [[44, 45, 46, 47]],

        [[60, 61, 62, 63]]])) 

chunk_dim2:(tensor([[[ 0],
         [ 4],
         [ 8],
         [12]],

        [[16],
         [20],
         [24],
         [28]],

        [[32],
         [36],
         [40],
         [44]],

        [[48],
         [52],
         [56],
         [60]]]), tensor([[[ 1],
         [ 5],
         [ 9],
         [13]],

        [[17],
         [21],
         [25],
         [29]],

        [[33],
         [37],
         [41],
         [45]],

        [[49],
         [53],
         [57],
         [61]]]), tensor([[[ 2],
         [ 6],
         [10],
         [14]],

        [[18],
         [22],
         [26],
         [30]],

        [[34],
         [38],
         [42],
         [46]],

        [[50],
         [54],
         [58],
         [62]]]), tensor([[[ 3],
         [ 7],
         [11],
         [15]],

        [[19],
         [23],
         [27],
         [31]],

        [[35],
         [39],
         [43],
         [47]],

        [[51],
         [55],
         [59],
         [63]]]))