1 TasedNet轻量化
TASED-Net 是一种用于视频显著性检测的新型全卷积网络架构。主要思想简单但有效:对 3D 视频特征进行空间解码,同时联合聚合所有时间信息。
github:https://github.com/MichiganCOG/TASED-Net
TasedNet官方仓库预训练权重有82M,不适合在移动端和PC端(会显著增加软件体积)进行部署。如果需要在移动端或者PC端部署,则需要对现有网络进行轻量化设计。
2 原有模型架构
原有模型架构代码如下:
import torch
from torch import nn
class TASED_v2(nn.Module):
    """Encoder-decoder 3D CNN that maps a video clip to a single 2D map.

    The encoder is ``base1`` .. ``base4`` (built from ``SepConv3d``,
    ``BasicConv3d`` and the ``Mixed_*`` blocks of this file).  The decoder
    (``convtsp*`` / ``unpool*``) restores spatial resolution with
    max-unpooling, reusing the pooling indices recorded by the encoder.
    """

    def __init__(self):
        super().__init__()

        def batch_norm(channels):
            # Every decoder BatchNorm3d uses the same hyper-parameters.
            return nn.BatchNorm3d(channels, eps=1e-3, momentum=0.001, affine=True)

        # Pooling configurations shared by several layers.
        halve_hw = dict(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        quarter_t = dict(kernel_size=(4, 1, 1), stride=(4, 1, 1), padding=(0, 0, 0))

        # ---- encoder ----
        self.base1 = nn.Sequential(
            SepConv3d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool3d(**halve_hw),
            BasicConv3d(64, 64, kernel_size=1, stride=1),
            SepConv3d(64, 192, kernel_size=3, stride=1, padding=1),
        )
        self.maxp2 = nn.MaxPool3d(**halve_hw)
        # maxm*/maxt* only serve the decoder: they shrink the time axis and
        # record the spatial pooling indices used later for unpooling.
        self.maxm2 = nn.MaxPool3d(**quarter_t)
        self.maxt2 = nn.MaxPool3d(return_indices=True, **halve_hw)
        self.base2 = nn.Sequential(
            Mixed_3b(),
            Mixed_3c(),
        )
        self.maxp3 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
        self.maxm3 = nn.MaxPool3d(**quarter_t)
        self.maxt3 = nn.MaxPool3d(return_indices=True, **halve_hw)
        self.base3 = nn.Sequential(
            Mixed_4b(),
            Mixed_4c(),
            Mixed_4d(),
            Mixed_4e(),
            Mixed_4f(),
        )
        self.maxt4 = nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        self.maxp4 = nn.MaxPool3d(
            kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 0, 0), return_indices=True
        )
        self.base4 = nn.Sequential(
            Mixed_5b(),
            Mixed_5c(),
        )

        # ---- decoder ----
        self.convtsp1 = nn.Sequential(
            nn.Conv3d(1024, 1024, kernel_size=1, stride=1, bias=False),
            batch_norm(1024),
            nn.ReLU(),
            nn.ConvTranspose3d(1024, 832, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False),
            batch_norm(832),
            nn.ReLU(),
        )
        self.unpool1 = nn.MaxUnpool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 0, 0))
        self.convtsp2 = nn.Sequential(
            nn.ConvTranspose3d(832, 480, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False),
            batch_norm(480),
            nn.ReLU(),
        )
        self.unpool2 = nn.MaxUnpool3d(**halve_hw)
        self.convtsp3 = nn.Sequential(
            nn.ConvTranspose3d(480, 192, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False),
            batch_norm(192),
            nn.ReLU(),
        )
        self.unpool3 = nn.MaxUnpool3d(**halve_hw)
        self.convtsp4 = nn.Sequential(
            nn.ConvTranspose3d(192, 64, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
            batch_norm(64),
            nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=(2, 1, 1), stride=(2, 1, 1), bias=False),
            batch_norm(64),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 4, kernel_size=1, stride=1, bias=False),
            batch_norm(4),
            nn.ReLU(),
            nn.Conv3d(4, 4, kernel_size=(2, 1, 1), stride=(2, 1, 1), bias=False),
            batch_norm(4),
            nn.ReLU(),
            nn.ConvTranspose3d(4, 4, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
            nn.Conv3d(4, 1, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one map per batch element, shaped ``(batch, height, width)``.

        The closing ``view`` drops the channel and time axes, so it only
        succeeds when both have length 1 after the decoder (the time axis
        reaches 1 for 32-frame clips).
        """
        # ---- encoder (skip features are time-pooled for the decoder) ----
        skip3 = self.base1(x)
        feat = self.maxp2(skip3)
        skip3 = self.maxm2(skip3)
        _, idx3 = self.maxt2(skip3)

        skip2 = self.base2(feat)
        feat = self.maxp3(skip2)
        skip2 = self.maxm3(skip2)
        _, idx2 = self.maxt3(skip2)

        feat = self.maxt4(self.base3(feat))
        feat, idx1 = self.maxp4(feat)
        bottleneck = self.base4(feat)

        # ---- decoder ----
        dec = self.convtsp1(bottleneck)
        dec = self.unpool1(dec, idx1)
        dec = self.convtsp2(dec)
        dec = self.unpool2(dec, idx2, skip2.size())
        dec = self.convtsp3(dec)
        dec = self.unpool3(dec, idx3, skip3.size())
        dec = self.convtsp4(dec)
        return dec.view(dec.size(0), dec.size(3), dec.size(4))
class BasicConv3d(nn.Module):
    """Bias-free ``Conv3d`` followed by ``BatchNorm3d`` and ``ReLU``."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        self.conv = nn.Conv3d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU in that order."""
        return self.relu(self.bn(self.conv(x)))
class SepConv3d(nn.Module):
    """3D convolution factorised into a spatial and a temporal convolution.

    A ``(1, k, k)`` convolution is followed by a ``(k, 1, 1)`` convolution;
    each one has its own ``BatchNorm3d`` and ``ReLU``.  ``stride`` and
    ``padding`` are applied to the spatial axes by the first convolution and
    to the time axis by the second.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        k, s, p = kernel_size, stride, padding

        # Spatial part: acts on height/width only.
        self.conv_s = nn.Conv3d(
            in_planes, out_planes,
            kernel_size=(1, k, k), stride=(1, s, s), padding=(0, p, p), bias=False,
        )
        self.bn_s = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu_s = nn.ReLU()

        # Temporal part: acts on the time axis only.
        self.conv_t = nn.Conv3d(
            out_planes, out_planes,
            kernel_size=(k, 1, 1), stride=(s, 1, 1), padding=(p, 0, 0), bias=False,
        )
        self.bn_t = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu_t = nn.ReLU()

    def forward(self, x):
        """Run the spatial stage, then the temporal stage."""
        spatial = self.relu_s(self.bn_s(self.conv_s(x)))
        return self.relu_t(self.bn_t(self.conv_t(spatial)))
class Mixed_3b(nn.Module):
    """Four parallel branches over a 192-channel input, concatenated on dim 1.

    Branch widths are 64 + 128 + 32 + 32, giving 256 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 64, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 96, kernel_size=1, stride=1),
            SepConv3d(96, 128, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 16, kernel_size=1, stride=1),
            SepConv3d(16, 32, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 32, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_3c(nn.Module):
    """Four parallel branches over a 256-channel input, concatenated on dim 1.

    Branch widths are 128 + 192 + 96 + 64, giving 480 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(256, 128, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(256, 128, kernel_size=1, stride=1),
            SepConv3d(128, 192, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(256, 32, kernel_size=1, stride=1),
            SepConv3d(32, 96, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(256, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4b(nn.Module):
    """Four parallel branches over a 480-channel input, concatenated on dim 1.

    Branch widths are 192 + 208 + 48 + 64, giving 512 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(480, 192, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(480, 96, kernel_size=1, stride=1),
            SepConv3d(96, 208, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(480, 16, kernel_size=1, stride=1),
            SepConv3d(16, 48, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(480, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4c(nn.Module):
    """Four parallel branches over a 512-channel input, concatenated on dim 1.

    Branch widths are 160 + 224 + 64 + 64, giving 512 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(512, 160, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(512, 112, kernel_size=1, stride=1),
            SepConv3d(112, 224, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(512, 24, kernel_size=1, stride=1),
            SepConv3d(24, 64, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(512, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4d(nn.Module):
    """Four parallel branches over a 512-channel input, concatenated on dim 1.

    Branch widths are 128 + 256 + 64 + 64, giving 512 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(512, 128, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(512, 128, kernel_size=1, stride=1),
            SepConv3d(128, 256, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(512, 24, kernel_size=1, stride=1),
            SepConv3d(24, 64, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(512, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4e(nn.Module):
    """Four parallel branches over a 512-channel input, concatenated on dim 1.

    Branch widths are 112 + 288 + 64 + 64, giving 528 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(512, 112, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(512, 144, kernel_size=1, stride=1),
            SepConv3d(144, 288, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(512, 32, kernel_size=1, stride=1),
            SepConv3d(32, 64, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(512, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4f(nn.Module):
    """Four parallel branches over a 528-channel input, concatenated on dim 1.

    Branch widths are 256 + 320 + 128 + 128, giving 832 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(528, 256, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(528, 160, kernel_size=1, stride=1),
            SepConv3d(160, 320, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(528, 32, kernel_size=1, stride=1),
            SepConv3d(32, 128, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(528, 128, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_5b(nn.Module):
    """Four parallel branches over an 832-channel input, concatenated on dim 1.

    Branch widths are 256 + 320 + 128 + 128, giving 832 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(832, 256, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(832, 160, kernel_size=1, stride=1),
            SepConv3d(160, 320, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(832, 32, kernel_size=1, stride=1),
            SepConv3d(32, 128, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(832, 128, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_5c(nn.Module):
    """Four parallel branches over an 832-channel input, concatenated on dim 1.

    Branch widths are 384 + 384 + 128 + 128, giving 1024 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(832, 384, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(832, 192, kernel_size=1, stride=1),
            SepConv3d(192, 384, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(832, 48, kernel_size=1, stride=1),
            SepConv3d(48, 128, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(832, 128, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
模型输入为:(1, 3, 32, 224, 384),模型输出为:(1, 224, 384)。
3 剪枝模型进行轻量化
对模型进行剪枝,主要是减少各3D卷积层的通道数,同时将解码器中的反池化(MaxUnpool3d)替换为步长为2的转置卷积,因此编码器不再需要保存池化索引。剪枝后的模型代码如下:
import torch
from torch import nn
class TASED_v2(nn.Module):
    """Channel-pruned encoder-decoder 3D CNN producing one 2D map per clip.

    The encoder keeps the ``base1`` .. ``base4`` layout but uses the narrower
    ``Mixed_*`` blocks of this file.  The decoder is 64 channels wide and
    upsamples with stride-2 transposed convolutions, so no pooling indices
    are recorded or consumed.
    """

    def __init__(self):
        super().__init__()

        def batch_norm(channels):
            # Every decoder BatchNorm3d uses the same hyper-parameters.
            return nn.BatchNorm3d(channels, eps=1e-3, momentum=0.001, affine=True)

        def refine():
            # Stride-1 (1, 3, 3) transposed conv: spatial size is unchanged.
            return nn.Sequential(
                nn.ConvTranspose3d(64, 64, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False),
                batch_norm(64),
                nn.ReLU(),
            )

        def upsample2x():
            # kernel 4 / stride 2 / padding 1 doubles height and width exactly:
            # (n - 1) * 2 - 2 + 4 = 2n.  A 3x3 kernel would give 2n - 1.
            return nn.Sequential(
                nn.ConvTranspose3d(64, 64, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
                batch_norm(64),
                nn.ReLU(),
            )

        # ---- encoder ----
        self.base1 = nn.Sequential(
            SepConv3d(3, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            BasicConv3d(64, 64, kernel_size=1, stride=1),
            SepConv3d(64, 192, kernel_size=3, stride=1, padding=1),
        )
        self.maxp2 = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        self.base2 = nn.Sequential(
            Mixed_3b(),
            Mixed_3c(),
        )
        self.maxp3 = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
        self.base3 = nn.Sequential(
            Mixed_4b(),
            Mixed_4c(),
            Mixed_4d(),
            Mixed_4e(),
            Mixed_4f(),
        )
        self.maxt4 = nn.MaxPool3d(kernel_size=(2, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
        self.maxp4 = nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=(0, 0, 0))
        self.base4 = nn.Sequential(
            Mixed_5b(),
            Mixed_5c(),
        )

        # ---- decoder: alternate refinement and 2x upsampling ----
        self.convtsp1 = nn.Sequential(
            nn.Conv3d(256, 64, kernel_size=1, stride=1, bias=False),
            batch_norm(64),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 64, kernel_size=(1, 3, 3), stride=1, padding=(0, 1, 1), bias=False),
            batch_norm(64),
            nn.ReLU(),
        )
        self.convtsp11 = upsample2x()
        self.convtsp2 = refine()
        self.convtsp22 = upsample2x()
        self.convtsp3 = refine()
        self.convtsp33 = upsample2x()
        self.convtsp4 = nn.Sequential(
            nn.ConvTranspose3d(64, 64, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
            batch_norm(64),
            nn.ReLU(),
            nn.Conv3d(64, 64, kernel_size=(2, 1, 1), stride=(2, 1, 1), bias=False),
            batch_norm(64),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 4, kernel_size=1, stride=1, bias=False),
            batch_norm(4),
            nn.ReLU(),
            nn.Conv3d(4, 4, kernel_size=(2, 1, 1), stride=(2, 1, 1), bias=False),
            batch_norm(4),
            nn.ReLU(),
            nn.ConvTranspose3d(4, 4, kernel_size=(1, 4, 4), stride=(1, 2, 2), padding=(0, 1, 1), bias=False),
            nn.Conv3d(4, 1, kernel_size=1, stride=1, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one map per batch element, shaped ``(batch, height, width)``.

        The closing ``view`` drops the channel and time axes, so it only
        succeeds when both have length 1 after the decoder (the time axis
        reaches 1 for 32-frame clips).
        """
        # ---- encoder ----
        feat = self.maxp2(self.base1(x))
        feat = self.maxp3(self.base2(feat))
        feat = self.maxp4(self.maxt4(self.base3(feat)))
        feat = self.base4(feat)

        # ---- decoder ----
        decoder_stages = (
            self.convtsp1, self.convtsp11,
            self.convtsp2, self.convtsp22,
            self.convtsp3, self.convtsp33,
            self.convtsp4,
        )
        for stage in decoder_stages:
            feat = stage(feat)
        return feat.view(feat.size(0), feat.size(3), feat.size(4))
class BasicConv3d(nn.Module):
    """Bias-free ``Conv3d`` followed by ``BatchNorm3d`` and ``ReLU``."""

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        conv_options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
        self.conv = nn.Conv3d(in_planes, out_planes, **conv_options)
        self.bn = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU in that order."""
        normalised = self.bn(self.conv(x))
        return self.relu(normalised)
class SepConv3d(nn.Module):
    """3D convolution factorised into a spatial and a temporal convolution.

    A ``(1, k, k)`` convolution is followed by a ``(k, 1, 1)`` convolution;
    each one has its own ``BatchNorm3d`` and ``ReLU``.  ``stride`` and
    ``padding`` are applied to the spatial axes by the first convolution and
    to the time axis by the second.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
        super().__init__()
        k, s, p = kernel_size, stride, padding

        # Spatial part: acts on height/width only.
        self.conv_s = nn.Conv3d(
            in_planes, out_planes,
            kernel_size=(1, k, k), stride=(1, s, s), padding=(0, p, p), bias=False,
        )
        self.bn_s = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu_s = nn.ReLU()

        # Temporal part: acts on the time axis only.
        self.conv_t = nn.Conv3d(
            out_planes, out_planes,
            kernel_size=(k, 1, 1), stride=(s, 1, 1), padding=(p, 0, 0), bias=False,
        )
        self.bn_t = nn.BatchNorm3d(out_planes, eps=1e-3, momentum=0.001, affine=True)
        self.relu_t = nn.ReLU()

    def forward(self, x):
        """Run the spatial stage, then the temporal stage."""
        spatial = self.relu_s(self.bn_s(self.conv_s(x)))
        return self.relu_t(self.bn_t(self.conv_t(spatial)))
class Mixed_3b(nn.Module):
    """Pruned four-branch block over a 192-channel input, concatenated on dim 1.

    Branch widths are 48 + 96 + 16 + 32, giving 192 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 48, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 80, kernel_size=1, stride=1),
            SepConv3d(80, 96, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 16, kernel_size=1, stride=1),
            SepConv3d(16, 16, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 32, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_3c(nn.Module):
    """Pruned four-branch block over a 192-channel input, concatenated on dim 1.

    Branch widths are 48 + 48 + 48 + 48, giving 192 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 48, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 96, kernel_size=1, stride=1),
            SepConv3d(96, 48, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 32, kernel_size=1, stride=1),
            SepConv3d(32, 48, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 48, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4b(nn.Module):
    """Pruned four-branch block over a 192-channel input, concatenated on dim 1.

    Branch widths are 56 + 64 + 36 + 36, giving 192 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 56, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 48, kernel_size=1, stride=1),
            SepConv3d(48, 64, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 16, kernel_size=1, stride=1),
            SepConv3d(16, 36, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 36, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4c(nn.Module):
    """Pruned four-branch block over a 192-channel input, concatenated on dim 1.

    Branch widths are 64 + 64 + 32 + 32, giving 192 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 64, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 112, kernel_size=1, stride=1),
            SepConv3d(112, 64, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 24, kernel_size=1, stride=1),
            SepConv3d(24, 32, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 32, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4d(nn.Module):
    """Pruned four-branch block over a 192-channel input, concatenated on dim 1.

    Branch widths are 64 + 64 + 32 + 32, giving 192 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 64, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 112, kernel_size=1, stride=1),
            SepConv3d(112, 64, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 24, kernel_size=1, stride=1),
            SepConv3d(24, 32, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 32, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4e(nn.Module):
    """Pruned four-branch block over a 192-channel input, concatenated on dim 1.

    Branch widths are 48 + 64 + 48 + 32, giving 192 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 48, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 96, kernel_size=1, stride=1),
            SepConv3d(96, 64, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 32, kernel_size=1, stride=1),
            SepConv3d(32, 48, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 32, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_4f(nn.Module):
    """Pruned four-branch block over a 192-channel input, concatenated on dim 1.

    Branch widths are 64 + 64 + 32 + 32, giving 192 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 64, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 96, kernel_size=1, stride=1),
            SepConv3d(96, 64, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 32, kernel_size=1, stride=1),
            SepConv3d(32, 32, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 32, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_5b(nn.Module):
    """Pruned four-branch block over a 192-channel input, concatenated on dim 1.

    Branch widths are 48 + 48 + 32 + 64, giving 192 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 48, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 80, kernel_size=1, stride=1),
            SepConv3d(80, 48, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 32, kernel_size=1, stride=1),
            SepConv3d(32, 32, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
class Mixed_5c(nn.Module):
    """Pruned four-branch block over a 192-channel input, concatenated on dim 1.

    Branch widths are 64 + 64 + 64 + 64, giving 256 output channels.
    """

    def __init__(self):
        super().__init__()
        # 1x1x1 projection.
        self.branch0 = nn.Sequential(BasicConv3d(192, 64, kernel_size=1, stride=1))
        # 1x1x1 reduction followed by a separable 3x3x3 convolution.
        self.branch1 = nn.Sequential(
            BasicConv3d(192, 48, kernel_size=1, stride=1),
            SepConv3d(48, 64, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            BasicConv3d(192, 48, kernel_size=1, stride=1),
            SepConv3d(48, 64, kernel_size=3, stride=1, padding=1),
        )
        # Size-preserving 3x3x3 max-pool followed by a 1x1x1 projection.
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            BasicConv3d(192, 64, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Feed ``x`` to every branch and concatenate the results channel-wise."""
        branches = (self.branch0, self.branch1, self.branch2, self.branch3)
        return torch.cat([branch(x) for branch in branches], 1)
模型输入为:(1, 3, 32, 56, 96),模型输出为:(1, 64, 96)(按上面 forward 中各层的步长推算:高度 56→28→14→7→4→2,再经过五次 2 倍上采样得到 64;宽度 96→48→24→12→6→3,上采样后恢复为 96)。
4 训练代码
修改模型代码后,我们也需要对训练代码进行修改,修改后的代码如下
import sys
import os
import numpy as np
import cv2
import time
from datetime import timedelta
import torch
#from model import TASED_v2
#from model_v93 import TASED_v2
#from model_lightweight import TASED_v2
from model_lightweight2 import TASED_v2
from loss import KLDLoss
from dataset import DHF1KDataset, InfiniteDataLoader
from itertools import islice
from torch.utils.data import Dataset, DataLoader
def load_model(model_weight_path):
    """Build a ``TASED_v2`` and copy matching tensors from a checkpoint into it.

    Args:
        model_weight_path: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        The model, left on the CPU.  If the file does not exist the model is
        returned with its freshly initialised weights and a message is
        printed.  Checkpoint entries with an unknown name or a mismatching
        shape are reported and skipped rather than raising.
    """
    model = TASED_v2()
    if not os.path.isfile(model_weight_path):
        print(f'weight file {model_weight_path} not found')
        return model

    print('loading weight file')
    # map_location='cpu' lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine; callers move the model to its final device afterwards.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    weight_dict = torch.load(model_weight_path, map_location='cpu')
    model_dict = model.state_dict()
    for name, param in weight_dict.items():
        # Checkpoints saved from nn.DataParallel prefix every key with
        # 'module.'; strip only that leading prefix.
        if name.startswith('module.'):
            name = name[len('module.'):]
        if name not in model_dict:
            print(' name? ' + name)
            continue
        if param.size() == model_dict[name].size():
            # state_dict tensors share storage with the model, so the
            # in-place copy updates the model itself.
            model_dict[name].copy_(param)
        else:
            print(' size? ' + name, param.size(), model_dict[name].size())
    print(' loaded')
    return model
def train_epoch(epoch, model, data_loader, optimizer, loss_fn, device):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        epoch: Epoch number, used only in log messages.
        model: Network to optimise; moved to ``device`` and put in train mode.
        data_loader: Iterable yielding ``(data, target)`` batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        loss_fn: Callable ``loss_fn(output, target)`` returning a scalar tensor.
        device: Device the batches and the model are moved to.

    Returns:
        The mean of the per-batch losses, or ``0.0`` if no batch was yielded.
    """
    print('Training Epoch: {}'.format(epoch))
    model.to(device)
    model.train()

    total_loss = 0.0
    num_batches = 0
    for batch_idx, (data, target) in enumerate(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        if batch_idx % 10 == 0:
            print('train epoch: {}, lr: {:f} , batch_index: {}, loss: {}'.format(epoch, optimizer.param_groups[0]['lr'], batch_idx, batch_loss))

    # Count batches while iterating: materialising the loader a second time
    # just to take its length would load every sample again, and would see
    # zero batches for one-shot iterables.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches
def _load_s3d_backbone(model, s3d_file_weight):
    """Copy matching S3D backbone weights from ``s3d_file_weight`` into ``model``.

    Checkpoint keys of the form ``base.<n>.<rest>`` are renamed to
    ``base<stage>.<n - stage_start>.<rest>`` so they line up with the
    ``base1`` .. ``base4`` sub-modules.  Entries with an unknown name or a
    mismatching shape are reported and skipped.
    """
    if not os.path.isfile(s3d_file_weight):
        print ('s3d weight file?')
        return

    print ('loading s3d weight file')
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    weight_dict = torch.load(s3d_file_weight, map_location='cpu')
    model_dict = model.state_dict()
    # Index of the first flat 'base.<n>' layer belonging to each stage.
    stage_starts = [0, 5, 8, 14]
    for name, param in weight_dict.items():
        if name.startswith('module.'):
            name = name[len('module.'):]
        if name.startswith('base.'):
            parts = name.split('.')
            layer_idx = int(parts[1])
            # Number of later stage boundaries already reached = stage index.
            stage = sum(1 for start in stage_starts[1:] if layer_idx >= start)
            name = 'base%d.%d.' % (stage + 1, layer_idx - stage_starts[stage]) + '.'.join(parts[2:])
        if name not in model_dict:
            print (' name? ' + name)
            continue
        if param.size() == model_dict[name].size():
            model_dict[name].copy_(param)
        else:
            print (' size? ' + name, param.size(), model_dict[name].size())
    print (' loaded')


def main():
    ''' concise script for training '''
    # optional two command-line arguments: dataset root and output root
    path_indata = 'E:/dataset/DHF1k'
    path_output = './output'
    if len(sys.argv) > 1:
        path_indata = sys.argv[1]
    if len(sys.argv) > 2:
        path_output = sys.argv[2]

    batch_size = 4
    len_temporal = 32
    epochs = 400
    tasednet_file_weight = './model_weights/epoch199_loss_1.398715.pt'
    s3d_file_weight = './S3D_kinetics400.pt'

    # Each run writes its checkpoints into its own time-stamped directory.
    path_output = os.path.join(path_output, time.strftime("%m-%d_%H-%M-%S"))
    os.makedirs(path_output, exist_ok=True)

    if tasednet_file_weight != '':
        # Resume from an existing TASED-Net checkpoint.
        model = load_model(tasednet_file_weight)
    else:
        # NOTE(review): the source listing lost its indentation; the S3D
        # initialisation is assumed to apply only when no TASED-Net
        # checkpoint is given - confirm against the original script.
        model = TASED_v2()
        _load_s3d_backbone(model, s3d_file_weight)

    # Fall back to the CPU when CUDA is unavailable, and move the model
    # before the optimizer is constructed.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    torch.backends.cudnn.benchmark = False

    optimizer = torch.optim.Adam(model.parameters(), lr=0.00001, betas=(0.9, 0.98), eps=1e-9)
    criterion = KLDLoss()

    train_dataset = DHF1KDataset(path_indata, len_temporal)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)

    for epoch in range(epochs):
        loss = train_epoch(epoch, model, train_loader, optimizer, criterion, device)
        print(f'epoch: {epoch} complete, loss: {loss}')
        # A checkpoint is written after every epoch, tagged with its mean loss.
        save_model_path = os.path.join(path_output, f'epoch{epoch}_loss_{loss:04f}.pt')
        torch.save(model.state_dict(), save_model_path)


if __name__ == '__main__':
    main()
5 导出模型为苹果的coreml模型
使用以下代码导出重新训练后的轻量化TasedNet模型为苹果推理所需的coreml模型
import os
import time
import numpy as np
import torch
#from model_v93 import TASED
#from model_lightweight import TASED_v2
from model_lightweight2 import TASED_v2
import coremltools as ct
def load_model(model_path):
    """Build a ``TASED_v2`` and copy matching tensors from a checkpoint into it.

    Args:
        model_path: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        The model, left on the CPU.  If the file does not exist the model is
        returned with its freshly initialised weights and a message is
        printed.  Checkpoint entries with an unknown name or a mismatching
        shape are reported and skipped rather than raising.
    """
    model = TASED_v2()
    if not os.path.isfile(model_path):
        print ('weight file?')
        return model

    # load the weight file and copy the parameters
    started_ms = time.time() * 1000
    print ('loading weight file')
    # map_location='cpu' lets a checkpoint saved on a GPU be restored on a
    # CPU-only machine; the caller moves the model to its device afterwards.
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    weight_dict = torch.load(model_path, map_location='cpu')
    finished_ms = time.time() * 1000
    print("loading model takes %d millisenconds."%(finished_ms - started_ms))

    model_dict = model.state_dict()
    for name, param in weight_dict.items():
        # Checkpoints saved from nn.DataParallel prefix every key with
        # 'module.'; strip only that leading prefix.
        if name.startswith('module.'):
            name = name[len('module.'):]
        if name not in model_dict:
            print (' name? ' + name)
            continue
        if param.size() == model_dict[name].size():
            # state_dict tensors share storage with the model, so the
            # in-place copy updates the model itself.
            model_dict[name].copy_(param)
        else:
            print (' size? ' + name, param.size(), model_dict[name].size())
    print (' loaded')
    return model
if __name__ == '__main__':
    # Restore the retrained lightweight network and switch it to inference mode.
    checkpoint_path = './model_weights/epoch399_loss_1.206857.pt'
    run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = load_model(checkpoint_path).to(run_device)
    torch.backends.cudnn.benchmark = False
    net.eval()

    # Trace the network with a random clip of the shape used for deployment.
    clip_shape = (1, 3, 32, 56, 96)
    dummy_clip = torch.randn(*clip_shape).to(run_device)
    traced_net = torch.jit.trace(net, dummy_clip)

    # Run the traced module once as a smoke test, then print its structure.
    with torch.no_grad():
        outputs = traced_net(dummy_clip)
    print(traced_net)

    # Convert the traced module to Core ML with a fixed-shape float32 input.
    coreml_input = ct.TensorType(name='input', shape=clip_shape, dtype=np.float32)
    coreml_model = ct.convert(
        traced_net,
        inputs=[coreml_input],
        minimum_deployment_target=ct.target.macOS11,
    )
    coreml_model.save("TasedNet.mlmodel")
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:视频显著性检测模型TasedNet在移动端的轻量化设计
原文链接:https://www.stubbornhuang.com/3109/
发布于:2024年11月25日 19:46:30
修改于:2024年11月25日 19:46:30
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
52