Pytorch – 使用torch.matmul()替换torch.einsum(‘nctw,cd->ndtw’,(a,b))算子模式
1 pytorch的torch.matmul()函数
函数形式
torch.matmul(input, other, *, out=None) → Tensor
该函数主要是用于求解两个tensor的矩阵乘积。
该函数根据输入的两个tensor的维度的不同进行不一样的张量运算,如下所示
- 如果两个参数张量Tensor都是一维的,那么该函数返回两个张量的点积(标量)
- 如果两个参数张量Tensor都是二维的,那么该函数返回二维矩阵与二维矩阵的乘积
- 如果第一个参数张量Tensor是一维的,第二个张量Tensor是二维的,则先扩充第一个张量Tensor的维数,维数+1,然后进行二维矩阵乘法,在得到矩阵乘法结果后,移除结果张量的前置维度
- 如果第一个参数张量Tensor是二维的,第二个张量Tensor是一维的,则返回矩阵向量积
- 如果两个参数张量Tensor至少为一维且至少有一个参数N维(其中N>2),则返回batch矩阵乘法。如果第一个参数是一维的,则将 1 添加到其维度,以便批量矩阵相乘并在之后删除。如果第二个参数是一维的,则将 1 附加到其维度以用于批量矩阵倍数并在之后删除。非矩阵(即批量)维度是广播的(因此必须是可广播的)。例如,如果第一个参数张量Tensor为\left ( j\times 1\times n \times n \right )维张量,第二个参数张量Tensor为\left ( k\times n\times n \right )维张量,则结果为\left ( j\times k\times n\times n \right )维张量。
这里需要注意的是,在使用matmul函数在确定输入是否可以广播的时候,是通过两个输入张量Tensor的后两个维度是否满足矩阵相乘准则来判断的。比如说第一个输入张量维度为\left ( 100 \times 50 \times 3 \times 6 \right ),第二个输入张量的维度为\left ( 50 \times 6 \times 3 \right ),第一个张量的最后两个维度3 \times 6与第二个张量的最后两个维度6 \times 3可以根据二维矩阵乘法得到3 \times 3矩阵说明是可以广播,上面两个张量的计算的结果维度为\left ( 100 \times 50 \times 3 \times 3 \right )。
2 使用torch.matmul()替换torch.einsum('nctw,cd->ndtw',(a,b))模式
torch.einsum('nctw,cd->ndtw',(a,b))
的意思是将第一个维度为nctw
的张量与第二个维度为cd
的张量进行batch矩阵相乘得到结果张量,结果张量维度为ndtw
。
所以这个einsum模式是可以通过torch.matmul()函数进行改写的。
(1)资源收集自互联网,仅供自我学习,请在下载后24小时内删除该资源,如下载者将此资源用于其他非法用途,本站不承担任何法律责任;如有侵权,请立即联系我,马上删除!
(2)下载单个资源则点击立即下载或者立即购买按钮;本站VIP可下载本站所有资源。
(3)请不要使用手机以及电脑浏览器的无痕模式进行支付操作,以免造成支付成功但未显示下载链接。
(4)如遇支付问题或者资源失效问题请点击按钮点击反馈进行反馈或者发送说明邮件到stubbornhuang@qq.com
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:Pytorch – 使用torch.matmul()替换torch.einsum(‘nctw,cd->ndtw’,(a,b))算子模式
原文链接:https://www.stubbornhuang.com/2065/
发布于:2022年03月29日 15:44:42
修改于:2023年06月26日 20:22:54
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
50