1 Problem: TensorRT does not yet implement the einsum operator

ST-GCN uses the Einstein summation operator torch.einsum:

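def forward(self, x, A):
    # A: adjacency partitions of shape (kernel_size, V, V)
    assert A.size(0) == self.kernel_size

    x = self.conv(x)

    n, kc, t, v = x.size()
    # split the channel axis into (kernel_size, out_channels)
    x = x.view(n, self.kernel_size, kc//self.kernel_size, t, v)
    # sum over the k partitions and the v joints
    x = torch.einsum('nkctv,kvw->nctw', (x, A))

    return x.contiguous(), A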

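The einsum (Einstein summation) convention is convenient, but at the time of writing even TensorRT 8, the latest TensorRT release, does not support the einsum operator. To deploy ST-GCN with TensorRT, einsum support therefore has to be added in one of two ways:

  • The first approach is to write a custom plugin with TensorRT's plugin API that implements einsum;
  • The second approach is to replace torch.einsum in the PyTorch model with ordinary PyTorch operators, so that converting the ONNX model to TensorRT no longer fails because einsum cannot be found.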

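The first approach is considerably harder than the second: it requires familiarity with TensorRT's workflow and APIs for creating custom plugins. If the second approach works, it greatly reduces the amount of TensorRT C++ programming, because all of the work can be done at the Python level when exporting the ONNX model.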

2 Example: replacing einsum with ordinary operators

2.1 How the replacement works

To explain how torch.einsum can be replaced, take torch.einsum('nkctv,kvw->nctw', (a, b)) as the example. If you are not familiar with einsum, read this article first: https://zhuanlan.zhihu.com/p/361209187

torch.einsum('nkctv,kvw->nctw', (a, b))

This operator takes two input tensors:

  • a: tensor(n,k,c,t,v)
  • b: tensor(k,v,w)

and combines them through a batched matrix multiplication into the output tensor(n,c,t,w).

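Six dimensions take part in the computation: n, k, c, t, v and w. The output keeps only n, c, t and w, so the dimensions that must be summed over are k and v. If we first build a tensor that contains every dimension, tensor(n,k,c,t,v,w), and then sum it over k and v, we are left with exactly the required tensor(n,c,t,w).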
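Written element-wise (the symbol P for the full product is our own notation), the operator and the expand-then-sum idea are:

```latex
\mathrm{out}_{nctw} = \sum_{k}\sum_{v} a_{nkctv}\, b_{kvw}
                    = \sum_{k}\sum_{v} P_{nkctvw},
\qquad P_{nkctvw} = a_{nkctv}\, b_{kvw}
```

P is the six-dimensional tensor(n,k,c,t,v,w); every output element is the sum of k·v of its entries.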

How, then, do we expand the product into tensor(n,k,c,t,v,w)?

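First reshape a to tensor(n,k,c,t,v,1) and b to tensor(1,k,1,1,v,w), then compute a*b. Broadcasting stretches every size-1 axis to match the other operand, so the product is tensor(n,k,c,t,v,w). Summing it over k and v then gives the final result tensor(n,c,t,w).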
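The reshape-multiply-sum recipe does not depend on this particular equation. The sketch below generalizes it to any two-operand equation through a hypothetical helper named einsum_by_broadcast (it is not part of PyTorch). It assumes each label appears at most once per operand and that there is no "..." ellipsis, uses only permute, reshape, broadcasting multiplication and sum, and checks itself against torch.einsum on the CPU for the four equations used in this article.

```python
import torch


def einsum_by_broadcast(equation: str, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: two-operand einsum via reshape + broadcast multiply + sum.

    Assumes no repeated label inside one operand and no '...' ellipsis.
    """
    inputs, out = equation.replace(" ", "").split("->")
    sa, sb = inputs.split(",")

    # Global label order: order of first appearance across both operands.
    labels = list(dict.fromkeys(sa + sb))
    sizes = dict(zip(sa, a.shape))
    sizes.update(zip(sb, b.shape))

    def expand(spec: str, t: torch.Tensor) -> torch.Tensor:
        # Permute t so its axes follow the global label order, then reshape
        # so that every label this operand lacks becomes a size-1 axis.
        perm = sorted(range(len(spec)), key=lambda i: labels.index(spec[i]))
        shape = [sizes[ch] if ch in spec else 1 for ch in labels]
        return t.permute(*perm).reshape(shape)

    # Broadcasting turns the product into a tensor with one axis per label.
    prod = expand(sa, a) * expand(sb, b)

    # Sum over the labels that do not appear in the output.
    summed = [i for i, ch in enumerate(labels) if ch not in out]
    if summed:  # only reduce when some label is contracted away
        prod = prod.sum(dim=summed)

    # Reorder the remaining axes into the requested output order.
    kept = [ch for ch in labels if ch in out]
    return prod.permute(*[kept.index(ch) for ch in out]).contiguous()


if __name__ == "__main__":
    torch.manual_seed(0)
    cases = {
        "nkctv,kvw->nctw": ((2, 3, 4, 5, 6), (3, 6, 7)),
        "nctw,cd->ndtw": ((2, 3, 4, 5), (3, 6)),
        "bhnm,bdhm->bdhn": ((2, 3, 4, 5), (2, 6, 3, 5)),
        "nkctv,kcvw->nctw": ((2, 3, 4, 5, 6), (3, 4, 6, 7)),
    }
    for eq, (shape_a, shape_b) in cases.items():
        a, b = torch.rand(shape_a), torch.rand(shape_b)
        ok = torch.allclose(einsum_by_broadcast(eq, a, b), torch.einsum(eq, a, b), atol=1e-5)
        print(eq, ok)
```

Like the hand-written versions in section 2.2, this materializes the full product tensor, so it shares their memory cost.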

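Is the result of this replacement correct?

Comparing the outputs of the examples below, the printed values of the einsum result and of the replacement are identical.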

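Note: in testing, this method turned out to be slow in both training and inference and to consume a large amount of GPU memory, which can cause out-of-memory errors, because the intermediate tensor(n,k,c,t,v,w) is k·v times larger than the output tensor(n,c,t,w). For a newer replacement method, see: https://www.stubbornhuang.com/2065/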
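The memory cost follows directly from the shapes. The arithmetic below uses illustrative, ST-GCN-like sizes (batch 8, 3 partitions, 64 channels, 300 frames, 25 joints) and float32 storage; these sizes are assumptions chosen for the example, not measurements from this article.

```python
def broadcast_cost(n, k, c, t, v, w, bytes_per_elem=4):
    """Element and byte counts for 'nkctv,kvw->nctw' computed by broadcast-then-sum."""
    intermediate = n * k * c * t * v * w  # tensor(n,k,c,t,v,w)
    output = n * c * t * w                # tensor(n,c,t,w)
    return {
        "intermediate_elems": intermediate,
        "output_elems": output,
        "ratio": intermediate // output,  # equals k * v
        "intermediate_GiB": intermediate * bytes_per_elem / 2**30,
    }


if __name__ == "__main__":
    cost = broadcast_cost(n=8, k=3, c=64, t=300, v=25, w=25)
    print(cost)
```

With these sizes the intermediate tensor holds 288,000,000 elements against 3,840,000 in the output (a ratio of 75 = 3 × 25), roughly 1.07 GiB for a single layer's forward pass, and ST-GCN stacks several such layers.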

2.2 Conversion examples

2.2.1 Replacing torch.einsum("nctw,cd->ndtw",(a,b))

import torch

if __name__ == '__main__':
    n_dim = 2
    c_dim = 3
    t_dim = 4
    w_dim = 5
    d_dim = 6

    a = torch.rand(n_dim,c_dim,t_dim,w_dim).cuda()
    b = torch.rand(c_dim,d_dim).cuda()

    # 1 Using the einsum operator
    a_b_einsum = torch.einsum("nctw,cd->ndtw",(a,b))
    print(a_b_einsum.shape)
    print(a_b_einsum)

    # 2 Replacement method
    d = a.reshape(n_dim,c_dim,t_dim,w_dim,1)
    e = b.reshape(1,c_dim,1,1,d_dim)
    d_e = d * e
    g = torch.sum(d_e,dim=1)
    g = g.transpose(1, 3).transpose(2,3).contiguous()

    print(g.shape)
    print(g)

Output:

torch.Size([2, 6, 4, 5])
tensor([[[[0.1536, 0.8770, 1.5611, 1.1315, 0.4630],
          [1.2332, 0.6786, 0.7968, 0.8069, 1.1939],
          [0.9592, 1.7275, 1.8042, 0.9783, 1.6064],
          [0.6743, 0.3755, 1.6635, 1.0484, 1.1097]],

         [[0.0620, 0.8995, 0.8229, 0.4608, 0.2584],
          [0.1969, 0.2989, 0.1415, 0.6890, 0.1376],
          [0.8199, 0.8984, 0.7805, 0.3741, 0.8887],
          [0.0772, 0.1881, 0.7069, 0.8372, 0.8517]],

         [[0.1426, 0.9370, 1.5492, 0.9897, 0.4912],
          [1.2442, 0.6108, 0.7671, 0.8894, 1.1784],
          [0.9610, 1.7029, 1.9283, 1.0882, 1.7516],
          [0.8288, 0.3555, 1.7507, 1.1991, 1.1621]],

         [[0.0770, 0.7392, 0.8970, 0.5596, 0.2783],
          [0.4787, 0.3499, 0.3109, 0.6130, 0.4321],
          [0.7142, 0.9846, 0.9704, 0.5076, 0.9691],
          [0.2692, 0.2081, 0.8840, 0.7735, 0.7843]],

         [[0.0945, 0.6915, 1.1119, 0.5910, 0.3830],
          [0.9637, 0.3817, 0.5576, 0.6986, 0.8936],
          [0.6641, 1.2121, 1.5420, 0.9152, 1.4114],
          [0.7867, 0.2390, 1.3761, 0.9892, 0.8660]],

         [[0.1512, 0.5476, 1.4122, 1.1246, 0.4043],
          [1.3506, 0.6587, 0.8732, 0.5711, 1.3413],
          [0.6997, 1.5718, 1.6919, 0.9340, 1.3964],
          [0.7219, 0.3487, 1.5727, 0.7631, 0.8495]]],


        [[[1.6477, 0.9142, 0.7481, 1.4262, 1.1146],
          [1.3980, 1.6822, 0.8348, 0.6170, 0.5092],
          [1.1900, 1.7543, 0.4870, 1.6122, 0.4977],
          [1.0047, 1.7500, 1.1585, 0.9877, 1.3967]],

         [[0.5425, 0.1120, 0.2557, 0.9042, 0.6450],
          [1.0389, 1.0299, 0.6792, 0.1813, 0.0616],
          [0.6002, 0.9196, 0.1439, 0.8148, 0.3918],
          [0.4817, 0.7592, 0.2160, 0.5903, 0.3621]],

         [[1.7228, 0.9820, 0.8371, 1.7044, 1.0725],
          [1.4199, 1.9031, 0.8216, 0.5613, 0.4442],
          [1.4446, 1.9075, 0.6006, 1.7294, 0.4762],
          [1.0802, 1.8771, 1.2108, 0.9614, 1.4954]],

         [[0.7927, 0.3467, 0.3733, 0.9429, 0.6648],
          [0.9637, 1.0766, 0.6016, 0.2726, 0.1767],
          [0.7109, 1.0313, 0.2401, 0.9286, 0.3498],
          [0.5656, 0.9434, 0.4722, 0.6004, 0.6271]],

         [[1.3611, 0.8212, 0.7117, 1.4981, 0.7295],
          [1.0185, 1.5859, 0.5563, 0.3674, 0.2800],
          [1.3060, 1.5378, 0.5619, 1.3758, 0.3097],
          [0.8645, 1.5070, 0.9745, 0.6634, 1.2249]],

         [[1.6435, 1.0060, 0.7292, 1.1459, 0.9853],
          [1.0979, 1.3947, 0.6303, 0.6454, 0.5835],
          [1.0327, 1.5530, 0.4744, 1.4485, 0.3857],
          [0.9158, 1.6384, 1.2431, 0.8607, 1.4372]]]], device='cuda:0')
torch.Size([2, 6, 4, 5])
tensor([[[[0.1536, 0.8770, 1.5611, 1.1315, 0.4630],
          [1.2332, 0.6786, 0.7968, 0.8069, 1.1939],
          [0.9592, 1.7275, 1.8042, 0.9783, 1.6064],
          [0.6743, 0.3755, 1.6635, 1.0484, 1.1097]],

         [[0.0620, 0.8995, 0.8229, 0.4608, 0.2584],
          [0.1969, 0.2989, 0.1415, 0.6890, 0.1376],
          [0.8199, 0.8984, 0.7805, 0.3741, 0.8887],
          [0.0772, 0.1881, 0.7069, 0.8372, 0.8517]],

         [[0.1426, 0.9370, 1.5492, 0.9897, 0.4912],
          [1.2442, 0.6108, 0.7671, 0.8894, 1.1784],
          [0.9610, 1.7029, 1.9283, 1.0882, 1.7516],
          [0.8288, 0.3555, 1.7507, 1.1991, 1.1621]],

         [[0.0770, 0.7392, 0.8970, 0.5596, 0.2783],
          [0.4787, 0.3499, 0.3109, 0.6130, 0.4321],
          [0.7142, 0.9846, 0.9704, 0.5076, 0.9691],
          [0.2692, 0.2081, 0.8840, 0.7735, 0.7843]],

         [[0.0945, 0.6915, 1.1119, 0.5910, 0.3830],
          [0.9637, 0.3817, 0.5576, 0.6986, 0.8936],
          [0.6641, 1.2121, 1.5420, 0.9152, 1.4114],
          [0.7867, 0.2390, 1.3761, 0.9892, 0.8660]],

         [[0.1512, 0.5476, 1.4122, 1.1246, 0.4043],
          [1.3506, 0.6587, 0.8732, 0.5711, 1.3413],
          [0.6997, 1.5718, 1.6919, 0.9340, 1.3964],
          [0.7219, 0.3487, 1.5727, 0.7631, 0.8495]]],


        [[[1.6477, 0.9142, 0.7481, 1.4262, 1.1146],
          [1.3980, 1.6822, 0.8348, 0.6170, 0.5092],
          [1.1900, 1.7543, 0.4870, 1.6122, 0.4977],
          [1.0047, 1.7500, 1.1585, 0.9877, 1.3967]],

         [[0.5425, 0.1120, 0.2557, 0.9042, 0.6450],
          [1.0389, 1.0299, 0.6792, 0.1813, 0.0616],
          [0.6002, 0.9196, 0.1439, 0.8148, 0.3918],
          [0.4817, 0.7592, 0.2160, 0.5903, 0.3621]],

         [[1.7228, 0.9820, 0.8371, 1.7044, 1.0725],
          [1.4199, 1.9031, 0.8216, 0.5613, 0.4442],
          [1.4446, 1.9075, 0.6006, 1.7294, 0.4762],
          [1.0802, 1.8771, 1.2108, 0.9614, 1.4954]],

         [[0.7927, 0.3467, 0.3733, 0.9429, 0.6648],
          [0.9637, 1.0766, 0.6016, 0.2726, 0.1767],
          [0.7109, 1.0313, 0.2401, 0.9286, 0.3498],
          [0.5656, 0.9434, 0.4722, 0.6004, 0.6271]],

         [[1.3611, 0.8212, 0.7117, 1.4981, 0.7295],
          [1.0185, 1.5859, 0.5563, 0.3674, 0.2800],
          [1.3060, 1.5378, 0.5619, 1.3758, 0.3097],
          [0.8645, 1.5070, 0.9745, 0.6634, 1.2249]],

         [[1.6435, 1.0060, 0.7292, 1.1459, 0.9853],
          [1.0979, 1.3947, 0.6303, 0.6454, 0.5835],
          [1.0327, 1.5530, 0.4744, 1.4485, 0.3857],
          [0.9158, 1.6384, 1.2431, 0.8607, 1.4372]]]], device='cuda:0')
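In the replacement above, torch.sum(d_e, dim=1) leaves the axes in the order (n,t,w,d), and the two transposes move d to position 1: transpose(1, 3) gives (n,d,w,t) and transpose(2, 3) then gives (n,d,t,w). A single permute(0, 3, 1, 2) should do the same in one step. The minimal sketch below runs on the CPU, so no GPU is needed, and compares the results with torch.equal and torch.allclose instead of by eye.

```python
import torch

n_dim, c_dim, t_dim, w_dim, d_dim = 2, 3, 4, 5, 6
a = torch.rand(n_dim, c_dim, t_dim, w_dim)
b = torch.rand(c_dim, d_dim)

ref = torch.einsum("nctw,cd->ndtw", a, b)

# Broadcast product of shape (n, c, t, w, d), then sum over c -> (n, t, w, d)
g = (a.reshape(n_dim, c_dim, t_dim, w_dim, 1) * b.reshape(1, c_dim, 1, 1, d_dim)).sum(dim=1)

via_transpose = g.transpose(1, 3).transpose(2, 3).contiguous()  # as in the example above
via_permute = g.permute(0, 3, 1, 2).contiguous()                # (n, t, w, d) -> (n, d, t, w)

print(torch.equal(via_transpose, via_permute))       # same values in the same layout?
print(torch.allclose(via_permute, ref, atol=1e-6))  # matches einsum?
```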

2.2.2 Replacing torch.einsum('nkctv,kvw->nctw',(a,b))

import torch

if __name__ == '__main__':
    n_dim = 2
    k_dim = 3
    c_dim = 4
    t_dim = 5
    v_dim = 6
    w_dim = 7

    a = torch.rand(n_dim,k_dim,c_dim,t_dim,v_dim).cuda()
    b = torch.rand(k_dim,v_dim,w_dim).cuda()

    # 1 Using the einsum operator
    a_b_einsum = torch.einsum("nkctv,kvw->nctw",(a,b))
    print(a_b_einsum.shape)
    print(a_b_einsum)

    # 2 Replacement method
    d = a.reshape(n_dim,k_dim,c_dim,t_dim,v_dim,1)
    e = b.reshape(1,k_dim,1,1,v_dim,w_dim)
    d_e = d * e
    g = d_e.sum(dim=4)
    g = g.sum(dim=1)

    print(g.shape)
    print(g)

Output:

torch.Size([2, 4, 5, 7])
tensor([[[[5.1243, 6.1886, 6.4867, 5.9503, 4.3884, 5.9156, 6.3786],
          [3.5094, 4.5025, 4.6040, 4.4659, 3.2761, 4.0113, 4.6133],
          [2.8960, 4.4600, 4.1601, 4.2410, 3.6444, 3.7244, 4.3930],
          [3.5727, 5.2637, 5.0044, 4.0297, 3.4084, 3.7042, 5.0993],
          [4.2162, 4.6827, 4.8894, 4.5648, 4.2200, 4.3652, 4.9702]],

         [[3.8939, 5.8364, 5.4699, 6.0786, 5.3935, 4.7933, 6.8394],
          [3.6809, 4.2113, 4.4321, 4.3263, 3.7807, 4.3141, 4.5713],
          [3.0575, 4.7094, 4.5717, 4.6121, 3.8757, 4.3018, 5.0006],
          [3.2234, 5.1427, 4.6323, 4.9495, 3.8041, 4.5921, 6.3398],
          [3.0838, 4.8825, 4.6080, 5.0657, 3.6239, 4.6620, 5.5746]],

         [[4.7595, 5.1886, 5.4454, 5.1424, 4.6519, 5.8128, 5.7679],
          [2.0943, 3.4811, 3.4878, 3.4047, 2.7593, 2.7664, 4.2419],
          [3.1837, 4.4004, 4.3706, 3.9428, 3.6235, 3.9220, 4.5210],
          [4.3870, 5.1601, 5.4062, 5.1913, 4.6606, 5.0319, 5.4645],
          [2.8502, 3.9418, 3.8537, 4.1178, 3.3862, 3.5893, 4.7189]],

         [[4.0604, 4.7589, 5.1743, 5.0615, 3.8071, 4.6072, 5.6293],
          [2.4517, 4.2132, 3.8871, 3.6280, 3.2739, 3.5921, 4.2781],
          [3.6410, 4.8560, 4.9386, 4.5525, 3.4472, 4.6471, 5.1113],
          [3.1657, 4.5200, 4.6471, 4.5107, 3.3468, 4.1749, 5.6870],
          [3.1199, 4.6795, 4.2116, 4.7081, 3.6071, 4.9994, 4.9361]]],


        [[[4.7273, 4.9349, 5.3361, 4.8414, 4.2015, 5.2810, 5.3009],
          [4.6114, 5.0410, 5.8097, 5.3338, 4.8278, 5.4540, 5.9611],
          [2.7837, 3.5572, 3.1741, 3.8524, 3.2932, 3.7166, 4.0663],
          [4.3297, 5.0760, 5.7087, 4.9175, 5.1967, 5.3271, 6.2677],
          [3.7909, 4.4908, 4.7942, 5.0652, 3.3895, 5.2854, 4.9531]],

         [[3.3919, 4.7877, 4.8535, 4.6241, 4.1152, 4.1662, 5.2712],
          [4.5097, 4.7338, 5.5651, 5.1715, 4.4806, 4.3849, 5.2941],
          [3.9918, 5.7766, 5.1338, 6.2193, 4.4749, 5.3787, 6.2975],
          [3.5283, 3.8201, 4.4238, 4.1005, 3.0723, 4.5533, 4.0787],
          [3.1878, 3.8659, 4.4109, 4.6049, 3.6049, 4.5219, 4.7915]],

         [[3.6558, 4.3884, 5.4212, 3.9985, 3.5273, 4.6921, 4.8114],
          [4.1333, 5.6260, 5.6888, 4.8582, 4.8365, 4.8057, 5.4795],
          [4.5793, 5.6060, 5.6415, 5.9413, 5.2855, 6.0514, 6.9744],
          [3.2513, 3.6899, 3.8790, 3.4052, 2.2581, 4.0996, 3.7790],
          [4.5139, 4.0690, 5.2260, 5.0382, 3.7201, 4.9167, 5.5515]],

         [[3.6879, 4.8078, 5.0622, 4.1290, 3.5639, 4.0953, 4.3765],
          [3.4064, 4.1741, 4.6276, 3.8761, 3.8149, 4.1352, 4.8580],
          [3.8020, 4.5084, 4.6149, 4.4418, 3.7658, 4.2515, 4.2942],
          [3.3575, 3.8007, 4.1480, 4.3891, 3.8744, 4.0676, 4.9271],
          [5.1714, 5.1096, 5.9393, 5.7977, 5.3516, 5.7728, 6.0512]]]],
       device='cuda:0')
torch.Size([2, 4, 5, 7])
tensor([[[[5.1243, 6.1886, 6.4867, 5.9503, 4.3884, 5.9156, 6.3786],
          [3.5094, 4.5025, 4.6040, 4.4659, 3.2761, 4.0113, 4.6133],
          [2.8960, 4.4600, 4.1601, 4.2410, 3.6444, 3.7244, 4.3930],
          [3.5727, 5.2637, 5.0044, 4.0297, 3.4084, 3.7042, 5.0993],
          [4.2162, 4.6827, 4.8894, 4.5648, 4.2200, 4.3652, 4.9702]],

         [[3.8939, 5.8364, 5.4699, 6.0786, 5.3935, 4.7933, 6.8394],
          [3.6809, 4.2113, 4.4321, 4.3263, 3.7807, 4.3141, 4.5713],
          [3.0575, 4.7094, 4.5717, 4.6121, 3.8757, 4.3018, 5.0006],
          [3.2234, 5.1427, 4.6323, 4.9495, 3.8041, 4.5921, 6.3398],
          [3.0838, 4.8825, 4.6080, 5.0657, 3.6239, 4.6620, 5.5746]],

         [[4.7595, 5.1886, 5.4454, 5.1424, 4.6519, 5.8128, 5.7679],
          [2.0943, 3.4811, 3.4878, 3.4047, 2.7593, 2.7664, 4.2419],
          [3.1837, 4.4004, 4.3706, 3.9428, 3.6235, 3.9220, 4.5210],
          [4.3870, 5.1601, 5.4062, 5.1913, 4.6606, 5.0319, 5.4645],
          [2.8502, 3.9418, 3.8537, 4.1178, 3.3862, 3.5893, 4.7189]],

         [[4.0604, 4.7589, 5.1743, 5.0615, 3.8071, 4.6072, 5.6293],
          [2.4517, 4.2132, 3.8871, 3.6280, 3.2739, 3.5921, 4.2781],
          [3.6410, 4.8560, 4.9386, 4.5525, 3.4472, 4.6471, 5.1113],
          [3.1657, 4.5200, 4.6471, 4.5107, 3.3468, 4.1749, 5.6870],
          [3.1199, 4.6795, 4.2116, 4.7081, 3.6071, 4.9994, 4.9361]]],


        [[[4.7273, 4.9349, 5.3361, 4.8414, 4.2015, 5.2810, 5.3009],
          [4.6114, 5.0410, 5.8097, 5.3338, 4.8278, 5.4540, 5.9611],
          [2.7837, 3.5572, 3.1741, 3.8524, 3.2932, 3.7166, 4.0663],
          [4.3297, 5.0760, 5.7087, 4.9175, 5.1967, 5.3271, 6.2677],
          [3.7909, 4.4908, 4.7942, 5.0652, 3.3895, 5.2854, 4.9531]],

         [[3.3919, 4.7877, 4.8535, 4.6241, 4.1152, 4.1662, 5.2712],
          [4.5097, 4.7338, 5.5651, 5.1715, 4.4806, 4.3849, 5.2941],
          [3.9918, 5.7766, 5.1338, 6.2193, 4.4749, 5.3787, 6.2975],
          [3.5283, 3.8201, 4.4238, 4.1005, 3.0723, 4.5533, 4.0787],
          [3.1878, 3.8659, 4.4109, 4.6049, 3.6049, 4.5219, 4.7915]],

         [[3.6558, 4.3884, 5.4212, 3.9985, 3.5273, 4.6921, 4.8114],
          [4.1333, 5.6260, 5.6888, 4.8582, 4.8365, 4.8057, 5.4795],
          [4.5793, 5.6060, 5.6415, 5.9413, 5.2855, 6.0514, 6.9744],
          [3.2513, 3.6899, 3.8790, 3.4052, 2.2581, 4.0996, 3.7790],
          [4.5139, 4.0690, 5.2260, 5.0382, 3.7201, 4.9167, 5.5515]],

         [[3.6879, 4.8078, 5.0622, 4.1290, 3.5639, 4.0953, 4.3765],
          [3.4064, 4.1741, 4.6276, 3.8761, 3.8149, 4.1352, 4.8580],
          [3.8020, 4.5084, 4.6149, 4.4418, 3.7658, 4.2515, 4.2942],
          [3.3575, 3.8007, 4.1480, 4.3891, 3.8744, 4.0676, 4.9271],
          [5.1714, 5.1096, 5.9393, 5.7977, 5.3516, 5.7728, 6.0512]]]],
       device='cuda:0')
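The two successive reductions, d_e.sum(dim=4) followed by g.sum(dim=1), can also be written as a single sum over both contracted axes, since torch.sum accepts a tuple of dimensions. A minimal CPU sketch of that variant, checked against both the two-step version and einsum:

```python
import torch

n_dim, k_dim, c_dim, t_dim, v_dim, w_dim = 2, 3, 4, 5, 6, 7
a = torch.rand(n_dim, k_dim, c_dim, t_dim, v_dim)
b = torch.rand(k_dim, v_dim, w_dim)

# (n, k, c, t, v, 1) * (1, k, 1, 1, v, w) -> (n, k, c, t, v, w)
d_e = a.reshape(n_dim, k_dim, c_dim, t_dim, v_dim, 1) * b.reshape(1, k_dim, 1, 1, v_dim, w_dim)

two_step = d_e.sum(dim=4).sum(dim=1)  # as in the example above: v first, then k
one_step = d_e.sum(dim=(1, 4))        # k and v in a single reduction

ref = torch.einsum("nkctv,kvw->nctw", a, b)
print(one_step.shape)
print(torch.allclose(two_step, one_step, atol=1e-5), torch.allclose(one_step, ref, atol=1e-5))
```

Either form still builds the full six-dimensional d_e, so the memory note in section 2.1 applies to both.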

2.2.3 Replacing torch.einsum("bhnm,bdhm->bdhn",(a,b))

import torch

if __name__ == '__main__':
    b_dim = 2
    h_dim = 3
    n_dim = 4
    m_dim = 5
    d_dim = 6

    a = torch.rand(b_dim,h_dim,n_dim,m_dim).cuda()
    b = torch.rand(b_dim,d_dim,h_dim,m_dim).cuda()

    # 1 Using the einsum operator
    a_b_einsum = torch.einsum("bhnm,bdhm->bdhn",(a,b))
    print(a_b_einsum.shape)
    print(a_b_einsum)

    # 2 Replacement method
    d = a.reshape(b_dim,1,h_dim,n_dim,m_dim)
    e = b.reshape(b_dim,d_dim,h_dim,1,m_dim)
    d_e = d * e
    g = torch.sum(d_e,dim=-1)

    print(g.shape)
    print(g)

Output:

torch.Size([2, 6, 3, 4])
tensor([[[[0.7485, 1.1304, 1.5609, 0.7659],
          [0.4858, 0.6758, 0.7654, 1.5621],
          [1.0485, 0.9274, 1.7848, 0.9245]],

         [[0.8395, 1.2818, 1.6657, 0.6278],
          [0.7742, 0.9645, 1.0017, 2.3575],
          [1.3290, 1.0029, 2.0633, 1.4623]],

         [[0.5437, 0.8261, 1.6291, 0.8061],
          [0.9383, 1.8158, 0.9913, 3.1291],
          [1.3296, 1.1673, 2.0923, 1.1132]],

         [[0.6604, 0.9283, 1.3009, 0.7243],
          [0.6694, 0.8136, 0.6781, 1.4160],
          [0.5809, 0.3698, 0.8549, 0.7219]],

         [[0.7077, 1.1118, 1.3823, 0.5207],
          [0.8915, 1.0434, 1.0218, 2.4020],
          [1.0928, 0.7864, 1.8233, 1.3614]],

         [[0.8477, 1.2199, 1.2862, 1.1010],
          [0.9283, 1.9503, 0.8049, 3.2380],
          [1.1121, 0.8291, 1.5418, 1.0480]]],


        [[[1.4059, 1.2436, 0.9244, 1.3208],
          [0.6139, 0.8953, 1.3918, 0.3312],
          [0.3111, 0.7085, 0.8762, 1.3002]],

         [[2.0954, 1.4036, 1.4653, 1.7839],
          [0.5616, 0.6855, 1.1779, 0.4554],
          [1.1214, 0.9399, 1.1625, 1.5147]],

         [[1.5334, 1.1275, 1.0018, 1.3577],
          [0.7086, 1.0112, 1.6142, 0.5294],
          [1.0993, 1.0357, 1.3549, 1.8132]],

         [[1.8053, 1.2803, 1.0696, 1.4491],
          [0.9126, 1.2431, 1.9852, 0.5952],
          [1.3642, 1.8556, 1.7529, 2.6744]],

         [[1.4492, 1.1429, 1.0478, 1.2675],
          [0.7594, 1.0712, 1.7349, 0.5283],
          [1.4567, 1.9925, 1.7371, 2.5965]],

         [[1.8359, 1.0199, 0.9771, 1.4224],
          [1.5275, 1.6111, 2.1722, 0.8499],
          [0.5197, 1.2744, 1.1213, 1.8507]]]], device='cuda:0')
torch.Size([2, 6, 3, 4])
tensor([[[[0.7485, 1.1304, 1.5609, 0.7659],
          [0.4858, 0.6758, 0.7654, 1.5621],
          [1.0485, 0.9274, 1.7848, 0.9245]],

         [[0.8395, 1.2818, 1.6657, 0.6278],
          [0.7742, 0.9645, 1.0017, 2.3575],
          [1.3290, 1.0029, 2.0633, 1.4623]],

         [[0.5437, 0.8261, 1.6291, 0.8061],
          [0.9383, 1.8158, 0.9913, 3.1291],
          [1.3296, 1.1673, 2.0923, 1.1132]],

         [[0.6604, 0.9283, 1.3009, 0.7243],
          [0.6694, 0.8136, 0.6781, 1.4160],
          [0.5809, 0.3698, 0.8549, 0.7219]],

         [[0.7077, 1.1118, 1.3823, 0.5207],
          [0.8915, 1.0434, 1.0218, 2.4020],
          [1.0928, 0.7864, 1.8233, 1.3614]],

         [[0.8477, 1.2199, 1.2862, 1.1010],
          [0.9283, 1.9503, 0.8049, 3.2380],
          [1.1121, 0.8291, 1.5418, 1.0480]]],


        [[[1.4059, 1.2436, 0.9244, 1.3208],
          [0.6139, 0.8953, 1.3918, 0.3312],
          [0.3111, 0.7085, 0.8762, 1.3002]],

         [[2.0954, 1.4036, 1.4653, 1.7839],
          [0.5616, 0.6855, 1.1779, 0.4554],
          [1.1214, 0.9399, 1.1625, 1.5147]],

         [[1.5334, 1.1275, 1.0018, 1.3577],
          [0.7086, 1.0112, 1.6142, 0.5294],
          [1.0993, 1.0357, 1.3549, 1.8132]],

         [[1.8053, 1.2803, 1.0696, 1.4491],
          [0.9126, 1.2431, 1.9852, 0.5952],
          [1.3642, 1.8556, 1.7529, 2.6744]],

         [[1.4492, 1.1429, 1.0478, 1.2675],
          [0.7594, 1.0712, 1.7349, 0.5283],
          [1.4567, 1.9925, 1.7371, 2.5965]],

         [[1.8359, 1.0199, 0.9771, 1.4224],
          [1.5275, 1.6111, 2.1722, 0.8499],
          [0.5197, 1.2744, 1.1213, 1.8507]]]], device='cuda:0')
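Here the reshapes only insert size-1 axes: a becomes (b,1,h,n,m) and b becomes (b,d,h,1,m), so the product is (b,d,h,n,m) and summing the last axis removes m. The same step can be written with unsqueeze, which does not require the concrete dimension sizes. A small CPU sketch, wrapped in a function whose name, bhnm_bdhm_to_bdhn, is made up for this example:

```python
import torch


def bhnm_bdhm_to_bdhn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Replacement for torch.einsum("bhnm,bdhm->bdhn", x, y) using broadcasting."""
    # x: (b, h, n, m) -> (b, 1, h, n, m); y: (b, d, h, m) -> (b, d, h, 1, m)
    prod = x.unsqueeze(1) * y.unsqueeze(3)  # broadcasts to (b, d, h, n, m)
    return prod.sum(dim=-1)                 # sum over m -> (b, d, h, n)


x = torch.rand(2, 3, 4, 5)
y = torch.rand(2, 6, 3, 5)
out = bhnm_bdhm_to_bdhn(x, y)
print(out.shape)
print(torch.allclose(out, torch.einsum("bhnm,bdhm->bdhn", x, y), atol=1e-5))
```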

2.2.4 Replacing torch.einsum("nkctv,kcvw->nctw",(a,b))

import torch

if __name__ == '__main__':
    n_dim = 2
    k_dim = 3
    c_dim = 4
    t_dim = 5
    v_dim = 6
    w_dim = 7

    a = torch.rand(n_dim,k_dim,c_dim,t_dim,v_dim).cuda()
    b = torch.rand(k_dim,c_dim,v_dim,w_dim).cuda()

    # 1 Using the einsum operator
    a_b_einsum = torch.einsum("nkctv,kcvw->nctw",(a,b))
    print(a_b_einsum.shape)
    print(a_b_einsum)

    # 2 Replacement method
    d = a.reshape(n_dim,k_dim,c_dim,t_dim,v_dim,1)
    e = b.reshape(1,k_dim,c_dim,1,v_dim,w_dim)
    d_e = d * e
    g = d_e.sum(dim=4)
    g = g.sum(dim=1)

    print(g.shape)
    print(g)

Output:

torch.Size([2, 4, 5, 7])
tensor([[[[3.6886, 4.6802, 3.4045, 4.3325, 3.7815, 4.4445, 3.6630],
          [4.3359, 4.6445, 3.5855, 5.3637, 3.4696, 4.9049, 4.8237],
          [4.1031, 3.5945, 3.3687, 4.5750, 3.4489, 4.6796, 3.9085],
          [4.4883, 5.3206, 3.9503, 5.5086, 4.2464, 5.0378, 4.7638],
          [4.7572, 4.4212, 3.6666, 5.2617, 4.8326, 5.6596, 4.1010]],

         [[4.2742, 5.7762, 5.3966, 6.5586, 5.5349, 6.4641, 6.2513],
          [5.2969, 6.5553, 5.6853, 6.1812, 6.2185, 5.9048, 6.5468],
          [4.8239, 6.3793, 6.8062, 6.5025, 6.2583, 5.8821, 6.8121],
          [4.7322, 4.2670, 4.6996, 4.9050, 4.1259, 5.2866, 5.4602],
          [3.7583, 4.4740, 5.1556, 4.9040, 5.0501, 4.8276, 4.9932]],

         [[3.0533, 3.6292, 3.4620, 3.7909, 3.7128, 3.8389, 3.4113],
          [3.6842, 4.7901, 3.9888, 4.4786, 4.6155, 4.7465, 3.4218],
          [3.4355, 4.0149, 2.6561, 3.3563, 3.6032, 3.8693, 3.2605],
          [3.9030, 3.4557, 3.2782, 4.3971, 3.5487, 4.0032, 3.6672],
          [4.2805, 4.4385, 4.0839, 3.9391, 4.9937, 4.6652, 3.8699]],

         [[3.5206, 3.4749, 4.0478, 3.8961, 3.4374, 4.8120, 4.3148],
          [3.5103, 3.8852, 4.5542, 4.8058, 3.5427, 5.1224, 4.1953],
          [4.8176, 3.5536, 5.1059, 4.4407, 4.0484, 5.4371, 4.3007],
          [5.0861, 4.3964, 5.3681, 5.0272, 5.2155, 5.6755, 4.9088],
          [4.1324, 4.3184, 5.0787, 5.3404, 4.9100, 5.7238, 4.6557]]],


        [[[3.7440, 3.2831, 2.5068, 3.6752, 3.4652, 4.1801, 3.9745],
          [4.3378, 4.7787, 3.3870, 4.7359, 4.1092, 5.0005, 4.0490],
          [5.2680, 6.2474, 4.4929, 5.1507, 4.6503, 6.3000, 5.4094],
          [4.1845, 3.7848, 3.1942, 4.1688, 3.5709, 4.5376, 3.4854],
          [4.4916, 4.4309, 3.4901, 4.9215, 5.0050, 5.9588, 4.2079]],

         [[3.8371, 5.5432, 5.2889, 6.0612, 5.0549, 5.5150, 5.5913],
          [3.2560, 5.1776, 4.4259, 5.4172, 5.1254, 4.6338, 5.5096],
          [4.8119, 4.6175, 4.5121, 5.3523, 4.7609, 4.8680, 5.1802],
          [4.6942, 5.6804, 5.6723, 5.9701, 5.1850, 5.4067, 6.0220],
          [4.1174, 5.3936, 5.5749, 5.5767, 5.0618, 5.0008, 5.8322]],

         [[3.5461, 4.0710, 3.1561, 4.6295, 4.7151, 4.7079, 3.0401],
          [3.6591, 3.6458, 3.0113, 4.3259, 3.9257, 4.2912, 3.9183],
          [2.0226, 2.6473, 2.6139, 4.0020, 2.9864, 3.2500, 2.7126],
          [3.8039, 3.6157, 3.1439, 3.8964, 3.5679, 4.2046, 3.9065],
          [3.8460, 4.4008, 3.4570, 4.6452, 4.0701, 4.9926, 3.4568]],

         [[4.4745, 3.7711, 5.2040, 5.1922, 5.1194, 5.6031, 4.1346],
          [3.6972, 3.9783, 5.1616, 4.4995, 3.9579, 5.4845, 4.6744],
          [4.6724, 3.9152, 4.9706, 5.1842, 3.9180, 5.5437, 3.9079],
          [3.9847, 3.9112, 4.2006, 4.3341, 3.4161, 4.9568, 4.3607],
          [3.5865, 3.5239, 4.2117, 4.1858, 3.7528, 4.5605, 3.4211]]]],
       device='cuda:0')
torch.Size([2, 4, 5, 7])
tensor([[[[3.6886, 4.6802, 3.4045, 4.3325, 3.7815, 4.4445, 3.6630],
          [4.3359, 4.6445, 3.5855, 5.3637, 3.4696, 4.9049, 4.8237],
          [4.1031, 3.5945, 3.3687, 4.5750, 3.4489, 4.6796, 3.9085],
          [4.4883, 5.3206, 3.9503, 5.5086, 4.2464, 5.0378, 4.7638],
          [4.7572, 4.4212, 3.6666, 5.2617, 4.8326, 5.6596, 4.1010]],

         [[4.2742, 5.7762, 5.3966, 6.5586, 5.5349, 6.4641, 6.2513],
          [5.2969, 6.5553, 5.6853, 6.1812, 6.2185, 5.9048, 6.5468],
          [4.8239, 6.3793, 6.8062, 6.5025, 6.2583, 5.8821, 6.8121],
          [4.7322, 4.2670, 4.6996, 4.9050, 4.1259, 5.2866, 5.4602],
          [3.7583, 4.4740, 5.1556, 4.9040, 5.0501, 4.8276, 4.9932]],

         [[3.0533, 3.6292, 3.4620, 3.7909, 3.7128, 3.8389, 3.4113],
          [3.6842, 4.7901, 3.9888, 4.4786, 4.6155, 4.7465, 3.4218],
          [3.4355, 4.0149, 2.6561, 3.3563, 3.6032, 3.8693, 3.2605],
          [3.9030, 3.4557, 3.2782, 4.3971, 3.5487, 4.0032, 3.6672],
          [4.2805, 4.4385, 4.0839, 3.9391, 4.9937, 4.6652, 3.8699]],

         [[3.5206, 3.4749, 4.0478, 3.8961, 3.4374, 4.8120, 4.3148],
          [3.5103, 3.8852, 4.5542, 4.8058, 3.5427, 5.1224, 4.1953],
          [4.8176, 3.5536, 5.1059, 4.4407, 4.0484, 5.4371, 4.3007],
          [5.0861, 4.3964, 5.3681, 5.0272, 5.2155, 5.6755, 4.9088],
          [4.1324, 4.3184, 5.0787, 5.3404, 4.9100, 5.7238, 4.6557]]],


        [[[3.7440, 3.2831, 2.5068, 3.6752, 3.4652, 4.1801, 3.9745],
          [4.3378, 4.7787, 3.3870, 4.7359, 4.1092, 5.0005, 4.0490],
          [5.2680, 6.2474, 4.4929, 5.1507, 4.6503, 6.3000, 5.4094],
          [4.1845, 3.7848, 3.1942, 4.1688, 3.5709, 4.5376, 3.4854],
          [4.4916, 4.4309, 3.4901, 4.9215, 5.0050, 5.9588, 4.2079]],

         [[3.8371, 5.5432, 5.2889, 6.0612, 5.0549, 5.5150, 5.5913],
          [3.2560, 5.1776, 4.4259, 5.4172, 5.1254, 4.6338, 5.5096],
          [4.8119, 4.6175, 4.5121, 5.3523, 4.7609, 4.8680, 5.1802],
          [4.6942, 5.6804, 5.6723, 5.9701, 5.1850, 5.4067, 6.0220],
          [4.1174, 5.3936, 5.5749, 5.5767, 5.0618, 5.0008, 5.8322]],

         [[3.5461, 4.0710, 3.1561, 4.6295, 4.7151, 4.7079, 3.0401],
          [3.6591, 3.6458, 3.0113, 4.3259, 3.9257, 4.2912, 3.9183],
          [2.0226, 2.6473, 2.6139, 4.0020, 2.9864, 3.2500, 2.7126],
          [3.8039, 3.6157, 3.1439, 3.8964, 3.5679, 4.2046, 3.9065],
          [3.8460, 4.4008, 3.4570, 4.6452, 4.0701, 4.9926, 3.4568]],

         [[4.4745, 3.7711, 5.2040, 5.1922, 5.1194, 5.6031, 4.1346],
          [3.6972, 3.9783, 5.1616, 4.4995, 3.9579, 5.4845, 4.6744],
          [4.6724, 3.9152, 4.9706, 5.1842, 3.9180, 5.5437, 3.9079],
          [3.9847, 3.9112, 4.2006, 4.3341, 3.4161, 4.9568, 4.3607],
          [3.5865, 3.5239, 4.2117, 4.1858, 3.7528, 4.5605, 3.4211]]]],
       device='cuda:0')
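Because the replacement is also used during training, it is worth checking gradients and not only forward values. A minimal CPU sketch for "nkctv,kcvw->nctw" that backpropagates the same random upstream gradient through einsum and through the replacement (in float64, to keep rounding differences negligible) and compares the gradients of a and b:

```python
import torch

n_dim, k_dim, c_dim, t_dim, v_dim, w_dim = 2, 3, 4, 5, 6, 7
torch.manual_seed(0)
a = torch.rand(n_dim, k_dim, c_dim, t_dim, v_dim, dtype=torch.float64, requires_grad=True)
b = torch.rand(k_dim, c_dim, v_dim, w_dim, dtype=torch.float64, requires_grad=True)
upstream = torch.rand(n_dim, c_dim, t_dim, w_dim, dtype=torch.float64)


def replacement(a, b):
    # (n, k, c, t, v, 1) * (1, k, c, 1, v, w) -> (n, k, c, t, v, w), then sum over v and k
    d_e = a.reshape(n_dim, k_dim, c_dim, t_dim, v_dim, 1) * b.reshape(1, k_dim, c_dim, 1, v_dim, w_dim)
    return d_e.sum(dim=4).sum(dim=1)


grads = {}
for name, fn in (("einsum", lambda a, b: torch.einsum("nkctv,kcvw->nctw", a, b)),
                 ("replacement", replacement)):
    # Gradients of <fn(a, b), upstream> with respect to a and b
    grads[name] = torch.autograd.grad(fn(a, b), (a, b), grad_outputs=upstream)

print(all(torch.allclose(g1, g2) for g1, g2 in zip(grads["einsum"], grads["replacement"])))
```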

3 Using this replacement in the pytorch->onnx->TensorRT conversion

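If the model is converted along the pytorch->onnx->TensorRT route, the operator replacement must already be applied when the PyTorch model is exported to ONNX. The exported ONNX model then contains no einsum operator, so the onnx->TensorRT step no longer fails because einsum cannot be found.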
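Applied to the ST-GCN forward() shown in section 1, this means swapping the einsum line for the broadcast version before the model is exported. Below is a simplified, hypothetical stand-in for that layer: the class name GraphConvNoEinsum is made up, and it uses a plain 1x1 Conv2d instead of ST-GCN's actual convolution settings. It keeps the einsum path behind a flag only so the two can be compared on the CPU.

```python
import torch
import torch.nn as nn


class GraphConvNoEinsum(nn.Module):
    """Simplified, hypothetical stand-in for the ST-GCN graph convolution, without einsum."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        self.kernel_size = kernel_size
        # Assumption: a 1x1 convolution producing kernel_size groups of out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * kernel_size, kernel_size=1)

    def forward(self, x, A, use_einsum: bool = False):
        assert A.size(0) == self.kernel_size
        x = self.conv(x)
        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        if use_einsum:  # original formulation, kept only for comparison
            return torch.einsum("nkctv,kvw->nctw", x, A).contiguous(), A
        # (n, k, c, t, v, 1) * (1, k, 1, 1, v, w), then sum over v and k -> (n, c, t, w)
        y = x.unsqueeze(-1) * A.reshape(1, A.size(0), 1, 1, A.size(1), A.size(2))
        return y.sum(dim=4).sum(dim=1).contiguous(), A


layer = GraphConvNoEinsum(in_channels=3, out_channels=8, kernel_size=3).eval()
x = torch.rand(2, 3, 10, 18)  # (N, C, T, V)
A = torch.rand(3, 18, 18)     # (K, V, V) adjacency partitions
with torch.no_grad():
    out, _ = layer(x, A)
    ref, _ = layer(x, A, use_einsum=True)
print(out.shape, torch.allclose(out, ref, atol=1e-4))
```

For export, the layer would be called with the default use_einsum=False and passed to torch.onnx.export as usual; the resulting graph would then be expected to contain elementwise multiply and reduce-sum nodes rather than Einsum.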

References