1 torch.cat函数
形式
torch.cat(tensors, dim=0, *, out=None)
功能
在指定的维度连接给定序列的张量,所有张量必须具有相同的形状(连接维度除外)或者为空。
参数
- tensors:相同形状的张量序列,非空张量必须具有相同形状(连接维度除外)
- dim:张量连接的维度
使用示例
import torch
if __name__ == '__main__':
a = torch.randn(size=(2, 2))
print(a)
b = torch.randn(size=(2, 2))
print(b)
c = torch.cat((a, b), dim=0)
print(c)
d = torch.cat((a, b), dim=1)
print(d)
输出
tensor([[ 0.8793, 0.3727],
[-2.3334, -1.4567]])
tensor([[-1.0906, -1.2683],
[ 0.7161, -0.5843]])
tensor([[ 0.8793, 0.3727],
[-2.3334, -1.4567],
[-1.0906, -1.2683],
[ 0.7161, -0.5843]])
tensor([[ 0.8793, 0.3727, -1.0906, -1.2683],
[-2.3334, -1.4567, 0.7161, -0.5843]])
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – torch.cat函数
原文链接:https://www.stubbornhuang.com/2441/
发布于:2022年12月08日 15:51:48
修改于:2023年06月21日 17:44:48
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50