TensorRT – A general method for replacing the torch.einsum (Einstein summation) operator with combinations of ordinary torch operators
1 Problem: TensorRT does not yet implement the einsum operator
ST-GCN uses the Einstein summation operator torch.einsum:
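```python
# Excerpt from the ST-GCN graph-convolution layer (a method of an nn.Module
# that defines self.conv and self.kernel_size)
def forward(self, x, A):
    assert A.size(0) == self.kernel_size
    x = self.conv(x)
    n, kc, t, v = x.size()
    # Split the channels into kernel_size groups: (n, k, c, t, v)
    x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
    # Sum over k and v: (n, k, c, t, v) x (k, v, w) -> (n, c, t, w)
    x = torch.einsum('nkctv,kvw->nctw', (x, A))
    return x.contiguous(), A
```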
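The Einstein summation convention is certainly convenient, but even TensorRT 8, the latest release at the time of writing, did not support the einsum operator for this model, so deploying ST-GCN with TensorRT requires adding einsum support. (Current TensorRT documentation describes an einsum layer, IEinsumLayer, so check the supported-operator list of your TensorRT version first.) There are two ways to add it: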
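- First, use TensorRT's custom plugin API and write a plugin that adds einsum support;
- Second, replace torch.einsum in the PyTorch model with ordinary PyTorch operators, so that converting the ONNX model to TensorRT no longer fails because einsum cannot be found.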
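The first approach is considerably harder than the second, because it requires familiarity with TensorRT's custom-plugin workflow and the related APIs. If the second approach works, it greatly reduces the amount of C++ programming on the TensorRT side: all of the work is done in Python when exporting the ONNX model.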
2 Examples of replacing einsum with ordinary operators
2.1 How the replacement works
To explain the principle, take torch.einsum('nkctv,kvw->nctw', (a, b)) as an example. If you are not familiar with einsum, read this article first: https://zhuanlan.zhihu.com/p/361209187
torch.einsum('nkctv,kvw->nctw', (a, b))
This operator takes two input tensors:
- a: tensor(n,k,c,t,v)
- b: tensor(k,v,w)
and contracts them, much like a batched matrix multiplication, into the output tensor(n,c,t,w).
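Written out elementwise with the same index names, the operator computes:

```latex
\mathrm{out}_{n,c,t,w} = \sum_{k}\sum_{v} a_{n,k,c,t,v}\, b_{k,v,w}
```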
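Six indices take part in the computation (n, k, c, t, v and w), but the output keeps only n, c, t and w, so k and v are the indices being summed over. If we first build a tensor that carries all six indices, tensor(n,k,c,t,v,w), and then sum it over k and v, we obtain exactly the required tensor(n,c,t,w).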
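The rule "an index is summed if it appears in an input but not in the output" can be checked mechanically. A small sketch; the helper name `summed_indices` is hypothetical, and it assumes an explicit `->` output and no ellipsis:

```python
def summed_indices(equation):
    """Return the indices of an einsum equation that are summed out.

    Assumes an explicit '->' output and no ellipsis ('...').
    """
    inputs, output = equation.replace(" ", "").split("->")
    seen = []
    for operand in inputs.split(","):
        for ch in operand:
            if ch not in seen:
                seen.append(ch)
    # An index is summed iff it occurs in some input but not in the output
    return [ch for ch in seen if ch not in output]


if __name__ == "__main__":
    print(summed_indices("nkctv,kvw->nctw"))
    print(summed_indices("nctw,cd->ndtw"))
    print(summed_indices("bhnm,bdhm->bdhn"))
```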
So how do we expand to tensor(n,k,c,t,v,w)?
Reshape a to tensor(n,k,c,t,v,1) and b to tensor(1,k,1,1,v,w); the element-wise product a*b then broadcasts to tensor(n,k,c,t,v,w). Summing it over the k and v dimensions gives the final tensor(n,c,t,w).
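The same recipe generalizes to any two-operand einsum: bring both operands onto one common axis order that covers every index, insert size-1 axes where an operand lacks an index, multiply, sum the indices missing from the output, and permute into output order. A minimal sketch of that procedure; the function name `einsum_via_broadcast` is hypothetical, and it assumes an explicit `->`, no ellipsis, and no index repeated within one operand. It has the same memory cost as the manual version, since the full broadcast product is materialized.

```python
import torch


def einsum_via_broadcast(equation, x, y):
    """Evaluate a two-operand einsum using only permute/reshape, '*' and sum.

    Sketch assumptions: explicit '->' output, no ellipsis ('...'), and no
    index repeated within a single operand (no diagonals or traces).
    """
    inputs, out = equation.replace(" ", "").split("->")
    sx, sy = inputs.split(",")
    # Full index order: every index of either input, in first-seen order
    full = []
    for ch in sx + sy:
        if ch not in full:
            full.append(ch)

    def expand(t, subs):
        # Reorder t's axes to follow `full`, then insert size-1 axes for the
        # indices t lacks, e.g. b(k,v,w) -> (1,k,1,1,v,w) for 'nkctv,kvw'
        present = [ch for ch in full if ch in subs]
        t = t.permute(*[subs.index(ch) for ch in present])
        sizes = dict(zip(present, t.shape))
        return t.reshape([sizes.get(ch, 1) for ch in full])

    # Broadcast product: one axis per index in `full`
    prod = expand(x, sx) * expand(y, sy)
    # Sum over the indices that do not appear in the output
    summed = tuple(i for i, ch in enumerate(full) if ch not in out)
    if summed:
        prod = prod.sum(dim=summed)
    # Remaining axes are in `full` order; permute them into output order
    kept = [ch for ch in full if ch in out]
    return prod.permute(*[kept.index(ch) for ch in out])


if __name__ == "__main__":
    a = torch.rand(2, 3, 4, 5, 6)
    b = torch.rand(3, 6, 7)
    ref = torch.einsum("nkctv,kvw->nctw", a, b)
    out = einsum_via_broadcast("nkctv,kvw->nctw", a, b)
    print(out.shape, torch.allclose(out, ref, atol=1e-5))
```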
Is the result of this replacement correct?
Comparing the einsum output with the replacement output in the examples below, the results agree.
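In testing, however, this approach turned out to be slow in both training and inference and to consume a large amount of GPU memory, to the point of running out of memory, because the full intermediate tensor(n,k,c,t,v,w) is materialized before the sum. For a newer replacement, see: https://www.stubbornhuang.com/2065/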
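To put the memory cost in numbers, take hypothetical ST-GCN-like sizes n=32, k=3, c=64, t=300, v=w=25: the intermediate has 32·3·64·300·25·25 = 1,152,000,000 float32 values (about 4.6 GB) for a single layer, while the output has only 32·64·300·25 = 15,360,000. One way to avoid the 6-D tensor, not necessarily the method in the linked article, is to merge the summed indices k and v into one axis and use a single torch.matmul. A sketch; the function names are hypothetical:

```python
import torch


def nkctv_kvw_to_nctw(a, b):
    """torch.einsum('nkctv,kvw->nctw', a, b) computed with one matmul.

    a: (n, k, c, t, v), b: (k, v, w) -> (n, c, t, w)
    """
    n, k, c, t, v = a.shape
    w = b.shape[-1]
    # Move the summed axes k, v to the end and merge them: (n, c, t, k*v)
    a2 = a.permute(0, 2, 3, 1, 4).reshape(n, c, t, k * v)
    # Merge b's k, v axes in the same k-major order: (k*v, w)
    b2 = b.reshape(k * v, w)
    # (n, c, t, k*v) @ (k*v, w) -> (n, c, t, w)
    return torch.matmul(a2, b2)


def intermediate_elements(n, k, c, t, v, w):
    # Element count of the broadcast product (n, k, c, t, v, w)
    return n * k * c * t * v * w


if __name__ == "__main__":
    a = torch.rand(2, 3, 4, 5, 6)
    b = torch.rand(3, 6, 7)
    ref = torch.einsum("nkctv,kvw->nctw", a, b)
    print(torch.allclose(nkctv_kvw_to_nctw(a, b), ref, atol=1e-5))
    # Hypothetical sizes from the text above, 4 bytes per float32 value
    print(intermediate_elements(32, 3, 64, 300, 25, 25) * 4 / 1e9, "GB")
```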
2.2 Conversion examples
2.2.1 Replacing torch.einsum("nctw,cd->ndtw",(a,b))
```python
import torch

if __name__ == '__main__':
    n_dim = 2
    c_dim = 3
    t_dim = 4
    w_dim = 5
    d_dim = 6
    a = torch.rand(n_dim, c_dim, t_dim, w_dim).cuda()
    b = torch.rand(c_dim, d_dim).cuda()

    # 1 Using the einsum operator
    a_b_einsum = torch.einsum("nctw,cd->ndtw", (a, b))
    print(a_b_einsum.shape)
    print(a_b_einsum)

    # 2 Replacement
    d = a.reshape(n_dim, c_dim, t_dim, w_dim, 1)   # (n, c, t, w, 1)
    e = b.reshape(1, c_dim, 1, 1, d_dim)           # (1, c, 1, 1, d)
    d_e = d * e                                    # broadcast -> (n, c, t, w, d)
    g = torch.sum(d_e, dim=1)                      # sum over c -> (n, t, w, d)
    g = g.transpose(1, 3).transpose(2, 3).contiguous()  # -> (n, d, t, w)
    print(g.shape)
    print(g)
```
Output (the einsum result is printed first, then the result of the replacement; the two agree):
torch.Size([2, 6, 4, 5])
tensor([[[[0.1536, 0.8770, 1.5611, 1.1315, 0.4630],
[1.2332, 0.6786, 0.7968, 0.8069, 1.1939],
[0.9592, 1.7275, 1.8042, 0.9783, 1.6064],
[0.6743, 0.3755, 1.6635, 1.0484, 1.1097]],
[[0.0620, 0.8995, 0.8229, 0.4608, 0.2584],
[0.1969, 0.2989, 0.1415, 0.6890, 0.1376],
[0.8199, 0.8984, 0.7805, 0.3741, 0.8887],
[0.0772, 0.1881, 0.7069, 0.8372, 0.8517]],
[[0.1426, 0.9370, 1.5492, 0.9897, 0.4912],
[1.2442, 0.6108, 0.7671, 0.8894, 1.1784],
[0.9610, 1.7029, 1.9283, 1.0882, 1.7516],
[0.8288, 0.3555, 1.7507, 1.1991, 1.1621]],
[[0.0770, 0.7392, 0.8970, 0.5596, 0.2783],
[0.4787, 0.3499, 0.3109, 0.6130, 0.4321],
[0.7142, 0.9846, 0.9704, 0.5076, 0.9691],
[0.2692, 0.2081, 0.8840, 0.7735, 0.7843]],
[[0.0945, 0.6915, 1.1119, 0.5910, 0.3830],
[0.9637, 0.3817, 0.5576, 0.6986, 0.8936],
[0.6641, 1.2121, 1.5420, 0.9152, 1.4114],
[0.7867, 0.2390, 1.3761, 0.9892, 0.8660]],
[[0.1512, 0.5476, 1.4122, 1.1246, 0.4043],
[1.3506, 0.6587, 0.8732, 0.5711, 1.3413],
[0.6997, 1.5718, 1.6919, 0.9340, 1.3964],
[0.7219, 0.3487, 1.5727, 0.7631, 0.8495]]],
[[[1.6477, 0.9142, 0.7481, 1.4262, 1.1146],
[1.3980, 1.6822, 0.8348, 0.6170, 0.5092],
[1.1900, 1.7543, 0.4870, 1.6122, 0.4977],
[1.0047, 1.7500, 1.1585, 0.9877, 1.3967]],
[[0.5425, 0.1120, 0.2557, 0.9042, 0.6450],
[1.0389, 1.0299, 0.6792, 0.1813, 0.0616],
[0.6002, 0.9196, 0.1439, 0.8148, 0.3918],
[0.4817, 0.7592, 0.2160, 0.5903, 0.3621]],
[[1.7228, 0.9820, 0.8371, 1.7044, 1.0725],
[1.4199, 1.9031, 0.8216, 0.5613, 0.4442],
[1.4446, 1.9075, 0.6006, 1.7294, 0.4762],
[1.0802, 1.8771, 1.2108, 0.9614, 1.4954]],
[[0.7927, 0.3467, 0.3733, 0.9429, 0.6648],
[0.9637, 1.0766, 0.6016, 0.2726, 0.1767],
[0.7109, 1.0313, 0.2401, 0.9286, 0.3498],
[0.5656, 0.9434, 0.4722, 0.6004, 0.6271]],
[[1.3611, 0.8212, 0.7117, 1.4981, 0.7295],
[1.0185, 1.5859, 0.5563, 0.3674, 0.2800],
[1.3060, 1.5378, 0.5619, 1.3758, 0.3097],
[0.8645, 1.5070, 0.9745, 0.6634, 1.2249]],
[[1.6435, 1.0060, 0.7292, 1.1459, 0.9853],
[1.0979, 1.3947, 0.6303, 0.6454, 0.5835],
[1.0327, 1.5530, 0.4744, 1.4485, 0.3857],
[0.9158, 1.6384, 1.2431, 0.8607, 1.4372]]]], device='cuda:0')
torch.Size([2, 6, 4, 5])
tensor([[[[0.1536, 0.8770, 1.5611, 1.1315, 0.4630],
[1.2332, 0.6786, 0.7968, 0.8069, 1.1939],
[0.9592, 1.7275, 1.8042, 0.9783, 1.6064],
[0.6743, 0.3755, 1.6635, 1.0484, 1.1097]],
[[0.0620, 0.8995, 0.8229, 0.4608, 0.2584],
[0.1969, 0.2989, 0.1415, 0.6890, 0.1376],
[0.8199, 0.8984, 0.7805, 0.3741, 0.8887],
[0.0772, 0.1881, 0.7069, 0.8372, 0.8517]],
[[0.1426, 0.9370, 1.5492, 0.9897, 0.4912],
[1.2442, 0.6108, 0.7671, 0.8894, 1.1784],
[0.9610, 1.7029, 1.9283, 1.0882, 1.7516],
[0.8288, 0.3555, 1.7507, 1.1991, 1.1621]],
[[0.0770, 0.7392, 0.8970, 0.5596, 0.2783],
[0.4787, 0.3499, 0.3109, 0.6130, 0.4321],
[0.7142, 0.9846, 0.9704, 0.5076, 0.9691],
[0.2692, 0.2081, 0.8840, 0.7735, 0.7843]],
[[0.0945, 0.6915, 1.1119, 0.5910, 0.3830],
[0.9637, 0.3817, 0.5576, 0.6986, 0.8936],
[0.6641, 1.2121, 1.5420, 0.9152, 1.4114],
[0.7867, 0.2390, 1.3761, 0.9892, 0.8660]],
[[0.1512, 0.5476, 1.4122, 1.1246, 0.4043],
[1.3506, 0.6587, 0.8732, 0.5711, 1.3413],
[0.6997, 1.5718, 1.6919, 0.9340, 1.3964],
[0.7219, 0.3487, 1.5727, 0.7631, 0.8495]]],
[[[1.6477, 0.9142, 0.7481, 1.4262, 1.1146],
[1.3980, 1.6822, 0.8348, 0.6170, 0.5092],
[1.1900, 1.7543, 0.4870, 1.6122, 0.4977],
[1.0047, 1.7500, 1.1585, 0.9877, 1.3967]],
[[0.5425, 0.1120, 0.2557, 0.9042, 0.6450],
[1.0389, 1.0299, 0.6792, 0.1813, 0.0616],
[0.6002, 0.9196, 0.1439, 0.8148, 0.3918],
[0.4817, 0.7592, 0.2160, 0.5903, 0.3621]],
[[1.7228, 0.9820, 0.8371, 1.7044, 1.0725],
[1.4199, 1.9031, 0.8216, 0.5613, 0.4442],
[1.4446, 1.9075, 0.6006, 1.7294, 0.4762],
[1.0802, 1.8771, 1.2108, 0.9614, 1.4954]],
[[0.7927, 0.3467, 0.3733, 0.9429, 0.6648],
[0.9637, 1.0766, 0.6016, 0.2726, 0.1767],
[0.7109, 1.0313, 0.2401, 0.9286, 0.3498],
[0.5656, 0.9434, 0.4722, 0.6004, 0.6271]],
[[1.3611, 0.8212, 0.7117, 1.4981, 0.7295],
[1.0185, 1.5859, 0.5563, 0.3674, 0.2800],
[1.3060, 1.5378, 0.5619, 1.3758, 0.3097],
[0.8645, 1.5070, 0.9745, 0.6634, 1.2249]],
[[1.6435, 1.0060, 0.7292, 1.1459, 0.9853],
[1.0979, 1.3947, 0.6303, 0.6454, 0.5835],
[1.0327, 1.5530, 0.4744, 1.4485, 0.3857],
[0.9158, 1.6384, 1.2431, 0.8607, 1.4372]]]], device='cuda:0')
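The two transposes above amount to the single permutation (n,t,w,d) -> (n,d,t,w), i.e. permute(0, 3, 1, 2). Because only c is summed here, the same einsum is also a plain matrix product over c. A sketch that checks both forms with torch.allclose instead of comparing printed values (CPU tensors, sizes as in the example):

```python
import torch

n_dim, c_dim, t_dim, w_dim, d_dim = 2, 3, 4, 5, 6
a = torch.rand(n_dim, c_dim, t_dim, w_dim)
b = torch.rand(c_dim, d_dim)
ref = torch.einsum("nctw,cd->ndtw", a, b)

# Broadcast-and-sum as in the text, with one permute instead of two transposes
g = (a.reshape(n_dim, c_dim, t_dim, w_dim, 1) * b.reshape(1, c_dim, 1, 1, d_dim)).sum(dim=1)
g = g.permute(0, 3, 1, 2)  # (n, t, w, d) -> (n, d, t, w)

# Matrix-product form: move c last, multiply by b (c, d), move d back to dim 1
m = torch.matmul(a.permute(0, 2, 3, 1), b).permute(0, 3, 1, 2)

print(torch.allclose(g, ref, atol=1e-5), torch.allclose(m, ref, atol=1e-5))
```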
2.2.2 Replacing torch.einsum('nkctv,kvw->nctw',(a,b))
```python
import torch

if __name__ == '__main__':
    n_dim = 2
    k_dim = 3
    c_dim = 4
    t_dim = 5
    v_dim = 6
    w_dim = 7
    a = torch.rand(n_dim, k_dim, c_dim, t_dim, v_dim).cuda()
    b = torch.rand(k_dim, v_dim, w_dim).cuda()

    # 1 Using the einsum operator
    a_b_einsum = torch.einsum("nkctv,kvw->nctw", (a, b))
    print(a_b_einsum.shape)
    print(a_b_einsum)

    # 2 Replacement
    d = a.reshape(n_dim, k_dim, c_dim, t_dim, v_dim, 1)   # (n, k, c, t, v, 1)
    e = b.reshape(1, k_dim, 1, 1, v_dim, w_dim)           # (1, k, 1, 1, v, w)
    d_e = d * e                                           # (n, k, c, t, v, w)
    g = d_e.sum(dim=4)                                    # sum over v -> (n, k, c, t, w)
    g = g.sum(dim=1)                                      # sum over k -> (n, c, t, w)
    print(g.shape)
    print(g)
```
Output (the einsum result is printed first, then the result of the replacement; the two agree):
torch.Size([2, 4, 5, 7])
tensor([[[[5.1243, 6.1886, 6.4867, 5.9503, 4.3884, 5.9156, 6.3786],
[3.5094, 4.5025, 4.6040, 4.4659, 3.2761, 4.0113, 4.6133],
[2.8960, 4.4600, 4.1601, 4.2410, 3.6444, 3.7244, 4.3930],
[3.5727, 5.2637, 5.0044, 4.0297, 3.4084, 3.7042, 5.0993],
[4.2162, 4.6827, 4.8894, 4.5648, 4.2200, 4.3652, 4.9702]],
[[3.8939, 5.8364, 5.4699, 6.0786, 5.3935, 4.7933, 6.8394],
[3.6809, 4.2113, 4.4321, 4.3263, 3.7807, 4.3141, 4.5713],
[3.0575, 4.7094, 4.5717, 4.6121, 3.8757, 4.3018, 5.0006],
[3.2234, 5.1427, 4.6323, 4.9495, 3.8041, 4.5921, 6.3398],
[3.0838, 4.8825, 4.6080, 5.0657, 3.6239, 4.6620, 5.5746]],
[[4.7595, 5.1886, 5.4454, 5.1424, 4.6519, 5.8128, 5.7679],
[2.0943, 3.4811, 3.4878, 3.4047, 2.7593, 2.7664, 4.2419],
[3.1837, 4.4004, 4.3706, 3.9428, 3.6235, 3.9220, 4.5210],
[4.3870, 5.1601, 5.4062, 5.1913, 4.6606, 5.0319, 5.4645],
[2.8502, 3.9418, 3.8537, 4.1178, 3.3862, 3.5893, 4.7189]],
[[4.0604, 4.7589, 5.1743, 5.0615, 3.8071, 4.6072, 5.6293],
[2.4517, 4.2132, 3.8871, 3.6280, 3.2739, 3.5921, 4.2781],
[3.6410, 4.8560, 4.9386, 4.5525, 3.4472, 4.6471, 5.1113],
[3.1657, 4.5200, 4.6471, 4.5107, 3.3468, 4.1749, 5.6870],
[3.1199, 4.6795, 4.2116, 4.7081, 3.6071, 4.9994, 4.9361]]],
[[[4.7273, 4.9349, 5.3361, 4.8414, 4.2015, 5.2810, 5.3009],
[4.6114, 5.0410, 5.8097, 5.3338, 4.8278, 5.4540, 5.9611],
[2.7837, 3.5572, 3.1741, 3.8524, 3.2932, 3.7166, 4.0663],
[4.3297, 5.0760, 5.7087, 4.9175, 5.1967, 5.3271, 6.2677],
[3.7909, 4.4908, 4.7942, 5.0652, 3.3895, 5.2854, 4.9531]],
[[3.3919, 4.7877, 4.8535, 4.6241, 4.1152, 4.1662, 5.2712],
[4.5097, 4.7338, 5.5651, 5.1715, 4.4806, 4.3849, 5.2941],
[3.9918, 5.7766, 5.1338, 6.2193, 4.4749, 5.3787, 6.2975],
[3.5283, 3.8201, 4.4238, 4.1005, 3.0723, 4.5533, 4.0787],
[3.1878, 3.8659, 4.4109, 4.6049, 3.6049, 4.5219, 4.7915]],
[[3.6558, 4.3884, 5.4212, 3.9985, 3.5273, 4.6921, 4.8114],
[4.1333, 5.6260, 5.6888, 4.8582, 4.8365, 4.8057, 5.4795],
[4.5793, 5.6060, 5.6415, 5.9413, 5.2855, 6.0514, 6.9744],
[3.2513, 3.6899, 3.8790, 3.4052, 2.2581, 4.0996, 3.7790],
[4.5139, 4.0690, 5.2260, 5.0382, 3.7201, 4.9167, 5.5515]],
[[3.6879, 4.8078, 5.0622, 4.1290, 3.5639, 4.0953, 4.3765],
[3.4064, 4.1741, 4.6276, 3.8761, 3.8149, 4.1352, 4.8580],
[3.8020, 4.5084, 4.6149, 4.4418, 3.7658, 4.2515, 4.2942],
[3.3575, 3.8007, 4.1480, 4.3891, 3.8744, 4.0676, 4.9271],
[5.1714, 5.1096, 5.9393, 5.7977, 5.3516, 5.7728, 6.0512]]]],
device='cuda:0')
torch.Size([2, 4, 5, 7])
tensor([[[[5.1243, 6.1886, 6.4867, 5.9503, 4.3884, 5.9156, 6.3786],
[3.5094, 4.5025, 4.6040, 4.4659, 3.2761, 4.0113, 4.6133],
[2.8960, 4.4600, 4.1601, 4.2410, 3.6444, 3.7244, 4.3930],
[3.5727, 5.2637, 5.0044, 4.0297, 3.4084, 3.7042, 5.0993],
[4.2162, 4.6827, 4.8894, 4.5648, 4.2200, 4.3652, 4.9702]],
[[3.8939, 5.8364, 5.4699, 6.0786, 5.3935, 4.7933, 6.8394],
[3.6809, 4.2113, 4.4321, 4.3263, 3.7807, 4.3141, 4.5713],
[3.0575, 4.7094, 4.5717, 4.6121, 3.8757, 4.3018, 5.0006],
[3.2234, 5.1427, 4.6323, 4.9495, 3.8041, 4.5921, 6.3398],
[3.0838, 4.8825, 4.6080, 5.0657, 3.6239, 4.6620, 5.5746]],
[[4.7595, 5.1886, 5.4454, 5.1424, 4.6519, 5.8128, 5.7679],
[2.0943, 3.4811, 3.4878, 3.4047, 2.7593, 2.7664, 4.2419],
[3.1837, 4.4004, 4.3706, 3.9428, 3.6235, 3.9220, 4.5210],
[4.3870, 5.1601, 5.4062, 5.1913, 4.6606, 5.0319, 5.4645],
[2.8502, 3.9418, 3.8537, 4.1178, 3.3862, 3.5893, 4.7189]],
[[4.0604, 4.7589, 5.1743, 5.0615, 3.8071, 4.6072, 5.6293],
[2.4517, 4.2132, 3.8871, 3.6280, 3.2739, 3.5921, 4.2781],
[3.6410, 4.8560, 4.9386, 4.5525, 3.4472, 4.6471, 5.1113],
[3.1657, 4.5200, 4.6471, 4.5107, 3.3468, 4.1749, 5.6870],
[3.1199, 4.6795, 4.2116, 4.7081, 3.6071, 4.9994, 4.9361]]],
[[[4.7273, 4.9349, 5.3361, 4.8414, 4.2015, 5.2810, 5.3009],
[4.6114, 5.0410, 5.8097, 5.3338, 4.8278, 5.4540, 5.9611],
[2.7837, 3.5572, 3.1741, 3.8524, 3.2932, 3.7166, 4.0663],
[4.3297, 5.0760, 5.7087, 4.9175, 5.1967, 5.3271, 6.2677],
[3.7909, 4.4908, 4.7942, 5.0652, 3.3895, 5.2854, 4.9531]],
[[3.3919, 4.7877, 4.8535, 4.6241, 4.1152, 4.1662, 5.2712],
[4.5097, 4.7338, 5.5651, 5.1715, 4.4806, 4.3849, 5.2941],
[3.9918, 5.7766, 5.1338, 6.2193, 4.4749, 5.3787, 6.2975],
[3.5283, 3.8201, 4.4238, 4.1005, 3.0723, 4.5533, 4.0787],
[3.1878, 3.8659, 4.4109, 4.6049, 3.6049, 4.5219, 4.7915]],
[[3.6558, 4.3884, 5.4212, 3.9985, 3.5273, 4.6921, 4.8114],
[4.1333, 5.6260, 5.6888, 4.8582, 4.8365, 4.8057, 5.4795],
[4.5793, 5.6060, 5.6415, 5.9413, 5.2855, 6.0514, 6.9744],
[3.2513, 3.6899, 3.8790, 3.4052, 2.2581, 4.0996, 3.7790],
[4.5139, 4.0690, 5.2260, 5.0382, 3.7201, 4.9167, 5.5515]],
[[3.6879, 4.8078, 5.0622, 4.1290, 3.5639, 4.0953, 4.3765],
[3.4064, 4.1741, 4.6276, 3.8761, 3.8149, 4.1352, 4.8580],
[3.8020, 4.5084, 4.6149, 4.4418, 3.7658, 4.2515, 4.2942],
[3.3575, 3.8007, 4.1480, 4.3891, 3.8744, 4.0676, 4.9271],
[5.1714, 5.1096, 5.9393, 5.7977, 5.3516, 5.7728, 6.0512]]]],
device='cuda:0')
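Plugged into the ST-GCN forward from section 1, the replacement could look like the sketch below. The class name `GraphConvNoEinsum`, the 1x1 Conv2d, and the channel sizes are hypothetical stand-ins for the original ST-GCN module, not its actual code; only the einsum line is replaced.

```python
import torch
import torch.nn as nn


class GraphConvNoEinsum(nn.Module):
    """Hypothetical minimal stand-in for the ST-GCN layer in section 1."""

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        # Assumed 1x1 conv that yields kernel_size groups of out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * kernel_size, kernel_size=1)

    def forward(self, x, A):
        assert A.size(0) == self.kernel_size
        x = self.conv(x)
        n, kc, t, v = x.size()
        x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v)
        k, _, w = A.size()
        # Replaces torch.einsum('nkctv,kvw->nctw', (x, A)):
        # (n, k, c, t, v, 1) * (1, k, 1, 1, v, w) -> (n, k, c, t, v, w)
        x = x.unsqueeze(-1) * A.reshape(1, k, 1, 1, v, w)
        x = x.sum(dim=4).sum(dim=1)  # sum over v, then k -> (n, c, t, w)
        return x.contiguous(), A


if __name__ == "__main__":
    layer = GraphConvNoEinsum(in_channels=3, out_channels=8, kernel_size=3)
    x = torch.rand(2, 3, 10, 5)  # (n, in_channels, t, v)
    A = torch.rand(3, 5, 5)      # (k, v, w)
    y, _ = layer(x, A)
    print(y.shape)
```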
2.2.3 Replacing torch.einsum("bhnm,bdhm->bdhn",(a,b))
```python
import torch

if __name__ == '__main__':
    b_dim = 2
    h_dim = 3
    n_dim = 4
    m_dim = 5
    d_dim = 6
    a = torch.rand(b_dim, h_dim, n_dim, m_dim).cuda()
    b = torch.rand(b_dim, d_dim, h_dim, m_dim).cuda()

    # 1 Using the einsum operator
    a_b_einsum = torch.einsum("bhnm,bdhm->bdhn", (a, b))
    print(a_b_einsum.shape)
    print(a_b_einsum)

    # 2 Replacement
    d = a.reshape(b_dim, 1, h_dim, n_dim, m_dim)   # (b, 1, h, n, m)
    e = b.reshape(b_dim, d_dim, h_dim, 1, m_dim)   # (b, d, h, 1, m)
    d_e = d * e                                    # (b, d, h, n, m)
    g = torch.sum(d_e, dim=-1)                     # sum over m -> (b, d, h, n)
    print(g.shape)
    print(g)
```
Output (the einsum result is printed first, then the result of the replacement; the two agree):
torch.Size([2, 6, 3, 4])
tensor([[[[0.7485, 1.1304, 1.5609, 0.7659],
[0.4858, 0.6758, 0.7654, 1.5621],
[1.0485, 0.9274, 1.7848, 0.9245]],
[[0.8395, 1.2818, 1.6657, 0.6278],
[0.7742, 0.9645, 1.0017, 2.3575],
[1.3290, 1.0029, 2.0633, 1.4623]],
[[0.5437, 0.8261, 1.6291, 0.8061],
[0.9383, 1.8158, 0.9913, 3.1291],
[1.3296, 1.1673, 2.0923, 1.1132]],
[[0.6604, 0.9283, 1.3009, 0.7243],
[0.6694, 0.8136, 0.6781, 1.4160],
[0.5809, 0.3698, 0.8549, 0.7219]],
[[0.7077, 1.1118, 1.3823, 0.5207],
[0.8915, 1.0434, 1.0218, 2.4020],
[1.0928, 0.7864, 1.8233, 1.3614]],
[[0.8477, 1.2199, 1.2862, 1.1010],
[0.9283, 1.9503, 0.8049, 3.2380],
[1.1121, 0.8291, 1.5418, 1.0480]]],
[[[1.4059, 1.2436, 0.9244, 1.3208],
[0.6139, 0.8953, 1.3918, 0.3312],
[0.3111, 0.7085, 0.8762, 1.3002]],
[[2.0954, 1.4036, 1.4653, 1.7839],
[0.5616, 0.6855, 1.1779, 0.4554],
[1.1214, 0.9399, 1.1625, 1.5147]],
[[1.5334, 1.1275, 1.0018, 1.3577],
[0.7086, 1.0112, 1.6142, 0.5294],
[1.0993, 1.0357, 1.3549, 1.8132]],
[[1.8053, 1.2803, 1.0696, 1.4491],
[0.9126, 1.2431, 1.9852, 0.5952],
[1.3642, 1.8556, 1.7529, 2.6744]],
[[1.4492, 1.1429, 1.0478, 1.2675],
[0.7594, 1.0712, 1.7349, 0.5283],
[1.4567, 1.9925, 1.7371, 2.5965]],
[[1.8359, 1.0199, 0.9771, 1.4224],
[1.5275, 1.6111, 2.1722, 0.8499],
[0.5197, 1.2744, 1.1213, 1.8507]]]], device='cuda:0')
torch.Size([2, 6, 3, 4])
tensor([[[[0.7485, 1.1304, 1.5609, 0.7659],
[0.4858, 0.6758, 0.7654, 1.5621],
[1.0485, 0.9274, 1.7848, 0.9245]],
[[0.8395, 1.2818, 1.6657, 0.6278],
[0.7742, 0.9645, 1.0017, 2.3575],
[1.3290, 1.0029, 2.0633, 1.4623]],
[[0.5437, 0.8261, 1.6291, 0.8061],
[0.9383, 1.8158, 0.9913, 3.1291],
[1.3296, 1.1673, 2.0923, 1.1132]],
[[0.6604, 0.9283, 1.3009, 0.7243],
[0.6694, 0.8136, 0.6781, 1.4160],
[0.5809, 0.3698, 0.8549, 0.7219]],
[[0.7077, 1.1118, 1.3823, 0.5207],
[0.8915, 1.0434, 1.0218, 2.4020],
[1.0928, 0.7864, 1.8233, 1.3614]],
[[0.8477, 1.2199, 1.2862, 1.1010],
[0.9283, 1.9503, 0.8049, 3.2380],
[1.1121, 0.8291, 1.5418, 1.0480]]],
[[[1.4059, 1.2436, 0.9244, 1.3208],
[0.6139, 0.8953, 1.3918, 0.3312],
[0.3111, 0.7085, 0.8762, 1.3002]],
[[2.0954, 1.4036, 1.4653, 1.7839],
[0.5616, 0.6855, 1.1779, 0.4554],
[1.1214, 0.9399, 1.1625, 1.5147]],
[[1.5334, 1.1275, 1.0018, 1.3577],
[0.7086, 1.0112, 1.6142, 0.5294],
[1.0993, 1.0357, 1.3549, 1.8132]],
[[1.8053, 1.2803, 1.0696, 1.4491],
[0.9126, 1.2431, 1.9852, 0.5952],
[1.3642, 1.8556, 1.7529, 2.6744]],
[[1.4492, 1.1429, 1.0478, 1.2675],
[0.7594, 1.0712, 1.7349, 0.5283],
[1.4567, 1.9925, 1.7371, 2.5965]],
[[1.8359, 1.0199, 0.9771, 1.4224],
[1.5275, 1.6111, 2.1722, 0.8499],
[0.5197, 1.2744, 1.1213, 1.8507]]]], device='cuda:0')
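Here only m is summed and (b, h) act as batch indices, so this einsum is also a batched matrix product (the same pattern as attention scores). A sketch of that alternative on CPU, which avoids the 5-D intermediate:

```python
import torch

b_dim, h_dim, n_dim, m_dim, d_dim = 2, 3, 4, 5, 6
a = torch.rand(b_dim, h_dim, n_dim, m_dim)
b = torch.rand(b_dim, d_dim, h_dim, m_dim)
ref = torch.einsum("bhnm,bdhm->bdhn", a, b)

# Batched matmul over (b, h): (b, h, n, m) @ (b, h, m, d) -> (b, h, n, d)
bm = b.permute(0, 2, 3, 1)                     # (b, d, h, m) -> (b, h, m, d)
out = torch.matmul(a, bm).permute(0, 3, 1, 2)  # (b, h, n, d) -> (b, d, h, n)

print(out.shape, torch.allclose(out, ref, atol=1e-5))
```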
2.2.4 Replacing torch.einsum("nkctv,kcvw->nctw",(a,b))
```python
import torch

if __name__ == '__main__':
    n_dim = 2
    k_dim = 3
    c_dim = 4
    t_dim = 5
    v_dim = 6
    w_dim = 7
    a = torch.rand(n_dim, k_dim, c_dim, t_dim, v_dim).cuda()
    b = torch.rand(k_dim, c_dim, v_dim, w_dim).cuda()

    # 1 Using the einsum operator
    a_b_einsum = torch.einsum("nkctv,kcvw->nctw", (a, b))
    print(a_b_einsum.shape)
    print(a_b_einsum)

    # 2 Replacement
    d = a.reshape(n_dim, k_dim, c_dim, t_dim, v_dim, 1)   # (n, k, c, t, v, 1)
    e = b.reshape(1, k_dim, c_dim, 1, v_dim, w_dim)       # (1, k, c, 1, v, w)
    d_e = d * e                                           # (n, k, c, t, v, w)
    g = d_e.sum(dim=4)                                    # sum over v -> (n, k, c, t, w)
    g = g.sum(dim=1)                                      # sum over k -> (n, c, t, w)
    print(g.shape)
    print(g)
```
Output (the einsum result is printed first, then the result of the replacement; the two agree):
torch.Size([2, 4, 5, 7])
tensor([[[[3.6886, 4.6802, 3.4045, 4.3325, 3.7815, 4.4445, 3.6630],
[4.3359, 4.6445, 3.5855, 5.3637, 3.4696, 4.9049, 4.8237],
[4.1031, 3.5945, 3.3687, 4.5750, 3.4489, 4.6796, 3.9085],
[4.4883, 5.3206, 3.9503, 5.5086, 4.2464, 5.0378, 4.7638],
[4.7572, 4.4212, 3.6666, 5.2617, 4.8326, 5.6596, 4.1010]],
[[4.2742, 5.7762, 5.3966, 6.5586, 5.5349, 6.4641, 6.2513],
[5.2969, 6.5553, 5.6853, 6.1812, 6.2185, 5.9048, 6.5468],
[4.8239, 6.3793, 6.8062, 6.5025, 6.2583, 5.8821, 6.8121],
[4.7322, 4.2670, 4.6996, 4.9050, 4.1259, 5.2866, 5.4602],
[3.7583, 4.4740, 5.1556, 4.9040, 5.0501, 4.8276, 4.9932]],
[[3.0533, 3.6292, 3.4620, 3.7909, 3.7128, 3.8389, 3.4113],
[3.6842, 4.7901, 3.9888, 4.4786, 4.6155, 4.7465, 3.4218],
[3.4355, 4.0149, 2.6561, 3.3563, 3.6032, 3.8693, 3.2605],
[3.9030, 3.4557, 3.2782, 4.3971, 3.5487, 4.0032, 3.6672],
[4.2805, 4.4385, 4.0839, 3.9391, 4.9937, 4.6652, 3.8699]],
[[3.5206, 3.4749, 4.0478, 3.8961, 3.4374, 4.8120, 4.3148],
[3.5103, 3.8852, 4.5542, 4.8058, 3.5427, 5.1224, 4.1953],
[4.8176, 3.5536, 5.1059, 4.4407, 4.0484, 5.4371, 4.3007],
[5.0861, 4.3964, 5.3681, 5.0272, 5.2155, 5.6755, 4.9088],
[4.1324, 4.3184, 5.0787, 5.3404, 4.9100, 5.7238, 4.6557]]],
[[[3.7440, 3.2831, 2.5068, 3.6752, 3.4652, 4.1801, 3.9745],
[4.3378, 4.7787, 3.3870, 4.7359, 4.1092, 5.0005, 4.0490],
[5.2680, 6.2474, 4.4929, 5.1507, 4.6503, 6.3000, 5.4094],
[4.1845, 3.7848, 3.1942, 4.1688, 3.5709, 4.5376, 3.4854],
[4.4916, 4.4309, 3.4901, 4.9215, 5.0050, 5.9588, 4.2079]],
[[3.8371, 5.5432, 5.2889, 6.0612, 5.0549, 5.5150, 5.5913],
[3.2560, 5.1776, 4.4259, 5.4172, 5.1254, 4.6338, 5.5096],
[4.8119, 4.6175, 4.5121, 5.3523, 4.7609, 4.8680, 5.1802],
[4.6942, 5.6804, 5.6723, 5.9701, 5.1850, 5.4067, 6.0220],
[4.1174, 5.3936, 5.5749, 5.5767, 5.0618, 5.0008, 5.8322]],
[[3.5461, 4.0710, 3.1561, 4.6295, 4.7151, 4.7079, 3.0401],
[3.6591, 3.6458, 3.0113, 4.3259, 3.9257, 4.2912, 3.9183],
[2.0226, 2.6473, 2.6139, 4.0020, 2.9864, 3.2500, 2.7126],
[3.8039, 3.6157, 3.1439, 3.8964, 3.5679, 4.2046, 3.9065],
[3.8460, 4.4008, 3.4570, 4.6452, 4.0701, 4.9926, 3.4568]],
[[4.4745, 3.7711, 5.2040, 5.1922, 5.1194, 5.6031, 4.1346],
[3.6972, 3.9783, 5.1616, 4.4995, 3.9579, 5.4845, 4.6744],
[4.6724, 3.9152, 4.9706, 5.1842, 3.9180, 5.5437, 3.9079],
[3.9847, 3.9112, 4.2006, 4.3341, 3.4161, 4.9568, 4.3607],
[3.5865, 3.5239, 4.2117, 4.1858, 3.7528, 4.5605, 3.4211]]]],
device='cuda:0')
torch.Size([2, 4, 5, 7])
tensor([[[[3.6886, 4.6802, 3.4045, 4.3325, 3.7815, 4.4445, 3.6630],
[4.3359, 4.6445, 3.5855, 5.3637, 3.4696, 4.9049, 4.8237],
[4.1031, 3.5945, 3.3687, 4.5750, 3.4489, 4.6796, 3.9085],
[4.4883, 5.3206, 3.9503, 5.5086, 4.2464, 5.0378, 4.7638],
[4.7572, 4.4212, 3.6666, 5.2617, 4.8326, 5.6596, 4.1010]],
[[4.2742, 5.7762, 5.3966, 6.5586, 5.5349, 6.4641, 6.2513],
[5.2969, 6.5553, 5.6853, 6.1812, 6.2185, 5.9048, 6.5468],
[4.8239, 6.3793, 6.8062, 6.5025, 6.2583, 5.8821, 6.8121],
[4.7322, 4.2670, 4.6996, 4.9050, 4.1259, 5.2866, 5.4602],
[3.7583, 4.4740, 5.1556, 4.9040, 5.0501, 4.8276, 4.9932]],
[[3.0533, 3.6292, 3.4620, 3.7909, 3.7128, 3.8389, 3.4113],
[3.6842, 4.7901, 3.9888, 4.4786, 4.6155, 4.7465, 3.4218],
[3.4355, 4.0149, 2.6561, 3.3563, 3.6032, 3.8693, 3.2605],
[3.9030, 3.4557, 3.2782, 4.3971, 3.5487, 4.0032, 3.6672],
[4.2805, 4.4385, 4.0839, 3.9391, 4.9937, 4.6652, 3.8699]],
[[3.5206, 3.4749, 4.0478, 3.8961, 3.4374, 4.8120, 4.3148],
[3.5103, 3.8852, 4.5542, 4.8058, 3.5427, 5.1224, 4.1953],
[4.8176, 3.5536, 5.1059, 4.4407, 4.0484, 5.4371, 4.3007],
[5.0861, 4.3964, 5.3681, 5.0272, 5.2155, 5.6755, 4.9088],
[4.1324, 4.3184, 5.0787, 5.3404, 4.9100, 5.7238, 4.6557]]],
[[[3.7440, 3.2831, 2.5068, 3.6752, 3.4652, 4.1801, 3.9745],
[4.3378, 4.7787, 3.3870, 4.7359, 4.1092, 5.0005, 4.0490],
[5.2680, 6.2474, 4.4929, 5.1507, 4.6503, 6.3000, 5.4094],
[4.1845, 3.7848, 3.1942, 4.1688, 3.5709, 4.5376, 3.4854],
[4.4916, 4.4309, 3.4901, 4.9215, 5.0050, 5.9588, 4.2079]],
[[3.8371, 5.5432, 5.2889, 6.0612, 5.0549, 5.5150, 5.5913],
[3.2560, 5.1776, 4.4259, 5.4172, 5.1254, 4.6338, 5.5096],
[4.8119, 4.6175, 4.5121, 5.3523, 4.7609, 4.8680, 5.1802],
[4.6942, 5.6804, 5.6723, 5.9701, 5.1850, 5.4067, 6.0220],
[4.1174, 5.3936, 5.5749, 5.5767, 5.0618, 5.0008, 5.8322]],
[[3.5461, 4.0710, 3.1561, 4.6295, 4.7151, 4.7079, 3.0401],
[3.6591, 3.6458, 3.0113, 4.3259, 3.9257, 4.2912, 3.9183],
[2.0226, 2.6473, 2.6139, 4.0020, 2.9864, 3.2500, 2.7126],
[3.8039, 3.6157, 3.1439, 3.8964, 3.5679, 4.2046, 3.9065],
[3.8460, 4.4008, 3.4570, 4.6452, 4.0701, 4.9926, 3.4568]],
[[4.4745, 3.7711, 5.2040, 5.1922, 5.1194, 5.6031, 4.1346],
[3.6972, 3.9783, 5.1616, 4.4995, 3.9579, 5.4845, 4.6744],
[4.6724, 3.9152, 4.9706, 5.1842, 3.9180, 5.5437, 3.9079],
[3.9847, 3.9112, 4.2006, 4.3341, 3.4161, 4.9568, 4.3607],
[3.5865, 3.5239, 4.2117, 4.1858, 3.7528, 4.5605, 3.4211]]]],
device='cuda:0')
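In this variant c is shared by both operands and kept in the output, so it behaves like a batch index while k and v are summed. A sketch of a matmul-based alternative: merge k and v into one axis on both sides and let torch.matmul broadcast the batch dimensions (n, c) against (c,):

```python
import torch

n_dim, k_dim, c_dim, t_dim, v_dim, w_dim = 2, 3, 4, 5, 6, 7
a = torch.rand(n_dim, k_dim, c_dim, t_dim, v_dim)
b = torch.rand(k_dim, c_dim, v_dim, w_dim)
ref = torch.einsum("nkctv,kcvw->nctw", a, b)

# a: (n, k, c, t, v) -> (n, c, t, k, v) -> (n, c, t, k*v)
a2 = a.permute(0, 2, 3, 1, 4).reshape(n_dim, c_dim, t_dim, k_dim * v_dim)
# b: (k, c, v, w) -> (c, k, v, w) -> (c, k*v, w), same k-major merge order
b2 = b.permute(1, 0, 2, 3).reshape(c_dim, k_dim * v_dim, w_dim)
# Batch dims (n, c) broadcast with (c,): (n, c, t, k*v) @ (c, k*v, w) -> (n, c, t, w)
out = torch.matmul(a2, b2)

print(out.shape, torch.allclose(out, ref, atol=1e-5))
```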
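3 Using this replacement when converting PyTorch -> ONNX -> TensorRT
If the model is converted along the PyTorch -> ONNX -> TensorRT route, the operator replacement must already be in place when the PyTorch model is exported to ONNX. The ONNX model then contains no einsum operator, so the ONNX -> TensorRT conversion no longer fails with a missing-einsum error.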
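Before exporting, it can help to confirm that einsum is really gone from the model. The TorchScript-based path of torch.onnx.export works by tracing, so an aten::einsum node in a torch.jit.trace graph is a good hint that an Einsum node would appear in the ONNX file. This is a quick proxy under that assumption, not a substitute for inspecting the exported model; the module and helper names below are hypothetical.

```python
import torch
import torch.nn as nn


class WithEinsum(nn.Module):
    def forward(self, x, A):
        return torch.einsum("nkctv,kvw->nctw", x, A)


class WithoutEinsum(nn.Module):
    def forward(self, x, A):
        n, k, c, t, v = x.shape
        w = A.shape[-1]
        # Broadcast product (n, k, c, t, v, w), then sum over v and k
        prod = x.reshape(n, k, c, t, v, 1) * A.reshape(1, k, 1, 1, v, w)
        return prod.sum(dim=4).sum(dim=1)


def uses_einsum(module, *inputs):
    # Trace the module and look for an aten::einsum node in its graph
    traced = torch.jit.trace(module, inputs)
    return any(node.kind() == "aten::einsum" for node in traced.graph.nodes())


if __name__ == "__main__":
    x = torch.rand(2, 3, 4, 5, 6)
    A = torch.rand(3, 6, 7)
    print(uses_einsum(WithEinsum(), x, A), uses_einsum(WithoutEinsum(), x, A))
```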
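Author: StubbornHuang
Copyright notice: this is an original article by the site owner; please cite the original link when reposting.
Original title: TensorRT – A general method for replacing the torch.einsum (Einstein summation) operator with combinations of ordinary torch operators
Original link: https://www.stubbornhuang.com/1741/
Published: 2021-10-08 16:13:30
Last modified: 2023-06-26 21:12:42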