视频智能分镜/视频转场检测模型TransNetV2导出为onnx和基于onnxruntime的python推理类
1 TransNetV2是一个视频镜头转场检测模型
TransNetV2是一个视频镜头转场检测模型,模型既保持了较高的准确性,又兼顾了快速的推理速度。
Github:https://github.com/soCzech/TransNetV2
1.1 TransNetV2导出为onnx
转换模型对应的pytorch权重,在官方仓库的inference-pytorch文件夹下新建一个export_to_onnx.py文件,并在该文件中添加以下代码:
import os
import torch
import cv2
import numpy as np
from transnetv2_pytorch import TransNetV2
def load_model(weights_path, device="cuda:0"):
    """Load TransNetV2 weights and return the model in eval mode on *device*.

    Args:
        weights_path: path to the pytorch ``.pth`` state-dict file.
        device: target device (string or ``torch.device``); default "cuda:0".

    Returns:
        The TransNetV2 model, in eval mode, moved to *device*.

    Raises:
        FileNotFoundError: if *weights_path* does not exist.
    """
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"[TransNetV2] ERROR: {weights_path} is not a file.")
    model = TransNetV2()
    # map_location lets CUDA-saved checkpoints load on CPU-only machines;
    # the subsequent .to(device) moves the weights where they belong.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval().to(device)
    return model
def export_to_onnx(model_weight_path, output_onnx_path):
    """Export TransNetV2 pytorch weights to an ONNX model file.

    Args:
        model_weight_path: path to the ``.pth`` weights file.
        output_onnx_path: destination path for the exported ``.onnx`` model.
    """
    # Use a torch.device consistently for both branches (the original mixed
    # torch.device and a bare string).
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load the pytorch model
    model = load_model(model_weight_path, device)
    # TransNetV2 takes uint8 RGB frames of shape (batch, frames, 27, 48, 3);
    # a zero tensor is enough to trace the graph. Renamed from `input`,
    # which shadowed the builtin.
    dummy_input = torch.zeros(1, 100, 27, 48, 3, dtype=torch.uint8).to(device)
    torch.onnx.export(model,
                      dummy_input,
                      output_onnx_path,
                      verbose=True,
                      input_names=['input'],
                      output_names=['single_frame_pred', 'all_frame_pred'])
    print('convert successful')
if __name__ == '__main__':
    # Paths are relative to the inference-pytorch folder of the official repo.
    weights_file = "./transnetv2-pytorch-weights.pth"
    onnx_file = './transnetv2.onnx'
    export_to_onnx(weights_file, onnx_file)
运行上述代码,即可导出pytorch模型为onnx,位置在inference-pytorch文件夹下的transnetv2.onnx。
1.2 基于导出的onnx模型进行推理
同样的,在inference-pytorch文件夹下新建一个inference_video_by_onnx.py文件,在该文件中增加以下代码,
import copy
import os
import torch
import cv2
import numpy as np
import onnxruntime as ort
from moviepy.editor import VideoFileClip
class TransNetV2_Onnx():
    """onnxruntime inference wrapper for the TransNetV2 shot-transition model.

    Slides a 100-frame window over the video (stride 50, overlap 25), runs
    the exported ONNX model on each window, and keeps only the middle 50
    per-frame transition scores of every window.
    """

    def __init__(self, onnx_path, window_size= 100, overlap_size= 20):
        """Create the onnxruntime session.

        Args:
            onnx_path: path to the exported TransNetV2 ``.onnx`` model.
            window_size: frames fed to the model per inference call.
            overlap_size: stored but not read by inference_video, which uses
                a hard-coded overlap of 25 — NOTE(review): confirm intent.

        Raises:
            FileNotFoundError: if *onnx_path* does not exist.
        """
        self.onnx_path = onnx_path
        if not os.path.exists(onnx_path):
            raise FileNotFoundError(f'{onnx_path} is not found')
        self.ort_session = ort.InferenceSession(onnx_path)
        # NOTE(review): InferenceSession raises on failure rather than
        # returning None, so this check looks unreachable.
        if self.ort_session is None:
            raise Exception(f'self.ort_session is None')
        self.window_size = window_size
        self.overlap_size = overlap_size

    def inference_video(self, input_video_path, output_video_dir, save_scene_txt=False):
        """Detect shot transitions in a video and write an annotated copy.

        Args:
            input_video_path: path of the video to analyze.
            output_video_dir: directory for the visualization video (and the
                optional scene txt file); created if missing.
            save_scene_txt: when True, also save "start end" frame pairs of
                every scene to ``<video_name>_scenes.txt``.

        Raises:
            FileNotFoundError: if *input_video_path* does not exist.
            Exception: if the first frame cannot be read.
        """
        if not os.path.exists(input_video_path):
            raise FileNotFoundError(f'{input_video_path} is not found')
        if not os.path.exists(output_video_dir):
            os.makedirs(output_video_dir)
        print(f'inference_video: {input_video_path}')
        # Open the video file
        cap = cv2.VideoCapture(input_video_path)
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f'video frame count:{frame_count}')
        all_predictions = []
        window_size = self.window_size # 100
        step_size = 50 # fixed stride of 50 frames per window
        overlap_frame_size = 25
        # Read the first frame once up front; it pads the head of the first window
        ret, first_frame = cap.read()
        if not ret:
            raise Exception("Cannot read first frame")
        # The model expects 48x27 RGB frames (cv2 decodes BGR)
        first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
        first_frame = cv2.resize(first_frame, (48, 27))
        # Rewind to the start of the video
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        current_frame = 0
        while current_frame < frame_count:
            # Build a window of window_size (100) frames
            frames = []
            # The very first window is padded with copies of the first frame
            if current_frame == 0:
                frames.extend([first_frame] * overlap_frame_size)
            # Read the real frames for this window
            actual_frames_to_read = min(window_size - len(frames), frame_count - current_frame)
            for _ in range(actual_frames_to_read):
                ret, frame = cap.read()
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = cv2.resize(frame, (48, 27))
                    frames.append(frame)
                else:
                    break
            actual_frames = len(frames)
            # If fewer than 100 frames were gathered, pad with the last frame
            if actual_frames < window_size:
                last_frame = frames[-1]
                padding_size = window_size - actual_frames
                frames.extend([last_frame] * padding_size)
            # Run inference; input shape is (1, window_size, 27, 48, 3) uint8
            input_video = np.array(frames, dtype=np.uint8)
            input_video = np.expand_dims(input_video, axis=0)
            single_frame_pred, _ = self.ort_session.run(None, {'input': input_video})
            # The exported model emits logits; sigmoid turns them into probabilities
            single_frame_pred = 1 / (1 + np.exp(-single_frame_pred))
            # Keep only the middle 50 predictions (window frames 25..74)
            predictions_to_save = single_frame_pred[0, overlap_frame_size:overlap_frame_size+step_size, 0]
            # The last window may cover fewer than step_size real frames; truncate
            if current_frame + step_size > frame_count:
                remaining_frames = frame_count - current_frame
                predictions_to_save = predictions_to_save[:remaining_frames]
            all_predictions.append(predictions_to_save)
            print(f"\r处理视频帧 {min(len(all_predictions) * 50, frame_count)}/{frame_count}", end="")
            # Advance to the start of the next window
            current_frame += step_size
            if current_frame < frame_count:
                # Seek back 25 frames so consecutive windows overlap
                cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame - overlap_frame_size)
        print("")
        cap.release()
        # Concatenate the per-window slices into one per-frame score array
        single_frame_pred = np.concatenate(all_predictions)
        # NOTE(review): assert is stripped under `python -O`; consider raising instead
        assert len(single_frame_pred) == frame_count, f"Predictions count {len(single_frame_pred)} doesn't match frame count {frame_count}"
        # Turn per-frame transition probabilities into [start, end] scene spans
        scenes = predictions_to_scenes(single_frame_pred)
        print(f'divide scene result: {scenes}')
        if save_scene_txt:
            input_video_name = os.path.basename(input_video_path).split('.')[0]
            output_txt_path = os.path.join(output_video_dir, f'{input_video_name}_scenes.txt')
            save_scenes_results_to_txt(scenes, output_txt_path)
        # Write the annotated visualization video
        visualize_video(input_video_path, scenes, output_video_dir)
        print(f'process completed')
def predictions_to_scenes(predictions: np.ndarray, threshold: float = 0.5):
    """Convert per-frame transition probabilities into scene spans.

    Frames with probability above *threshold* are treated as transition
    frames; the stretches between transitions become scenes.

    Args:
        predictions: 1-D array of per-frame transition probabilities.
        threshold: probability above which a frame counts as a transition.

    Returns:
        int32 array of shape (num_scenes, 2) holding inclusive
        [start_frame, end_frame] pairs.
    """
    binary = (predictions > threshold).astype(np.uint8)
    scenes = []
    prev, cur, scene_start = 0, -1, 0
    for idx, cur in enumerate(binary):
        # A 1->0 edge ends a transition run: a new scene starts here.
        if prev == 1 and cur == 0:
            scene_start = idx
        # A 0->1 edge starts a transition run: close the current scene.
        elif prev == 0 and cur == 1 and idx != 0:
            scenes.append([scene_start, idx])
        prev = cur
    # If the video ends inside a scene, close it at the final frame.
    if cur == 0:
        scenes.append([scene_start, idx])
    # Edge case: every frame predicted as a transition -> one whole-video scene.
    if not scenes:
        return np.array([[0, len(predictions) - 1]], dtype=np.int32)
    return np.array(scenes, dtype=np.int32)
def save_scenes_results_to_txt(scenes, output_txt_path):
    """Write scene spans to a text file, one "start end" pair per line.

    Args:
        scenes: iterable of [start_frame, end_frame] pairs (inclusive indices).
        output_txt_path: destination path of the text file (overwritten).
    """
    with open(output_txt_path, 'w') as f:
        # Tuple-unpack directly; the original enumerate index was unused.
        for start_frame, end_frame in scenes:
            f.write(f'{start_frame} {end_frame}\n')
    print(f'save scenes results to {output_txt_path}')
def convert_to_h264(input_video_path):
    """Re-encode a video to H.264 next to the input, with an `_h264` suffix.

    Args:
        input_video_path: path of the video to re-encode.

    Raises:
        Exception: if *input_video_path* does not exist.
    """
    if not os.path.exists(input_video_path):
        raise Exception(f'{input_video_path} does not exist')
    # Build "<name>_h264<ext>" with splitext; the previous
    # str.replace('.mp4', ...) was a no-op for non-mp4 inputs (the clip
    # would then overwrite its own source) and also matched '.mp4'
    # anywhere in the path.
    root, ext = os.path.splitext(input_video_path)
    new_output_video_path = f'{root}_h264{ext}'
    # Load the input video
    video_clip = VideoFileClip(input_video_path)
    try:
        # Write the output video with H.264 encoding
        video_clip.write_videofile(new_output_video_path, codec='libx264')
    finally:
        # Always release the underlying ffmpeg reader (resource leak fix).
        video_clip.close()
def visualize_video(input_video_path, scenes, output_video_dir):
    """Write a copy of the video with each frame labeled by its scene index.

    Args:
        input_video_path: path of the source video.
        scenes: iterable of [start_frame, end_frame] pairs with INCLUSIVE
            frame indices (as produced by predictions_to_scenes).
        output_video_dir: output directory, created if missing.

    Raises:
        Exception: if *input_video_path* does not exist.
    """
    if not os.path.exists(input_video_path):
        raise Exception(f'{input_video_path} is not exist')
    if not os.path.exists(output_video_dir):
        os.makedirs(output_video_dir)
    input_video_name = os.path.basename(input_video_path).split('.')[0]
    output_video_path = os.path.join(output_video_dir, f'{input_video_name}_camera_detect.mp4')
    video_read_cap = cv2.VideoCapture(input_video_path)
    frame_video_width = int(video_read_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_video_height = int(video_read_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_fps = int(video_read_cap.get(cv2.CAP_PROP_FPS))
    video_write_cap = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*'avc1'), frame_fps, (frame_video_width, frame_video_height))
    for i, scene_index in enumerate(scenes):
        start_frame = scene_index[0]
        end_frame = scene_index[1]
        # BUGFIX: scene spans are inclusive, so a scene holds
        # end_frame - start_frame + 1 frames; the previous "end - start"
        # silently dropped the last frame of every scene.
        num_frames = end_frame - start_frame + 1
        scene_str = f'scene {i}'
        # Seek to the first frame of this scene
        video_read_cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
        for _ in range(num_frames):
            ret, frame = video_read_cap.read()
            if ret:
                cv2.putText(frame, scene_str, (50, 50), cv2.FONT_HERSHEY_COMPLEX, 1.0, (0, 0, 255), 1, cv2.LINE_AA)
                video_write_cap.write(frame)
            else:
                break
    video_read_cap.release()
    video_write_cap.release()
    print(f'video: {output_video_dir} write success')
    # Re-encode with moviepy so the result is H.264 and broadly playable
    convert_to_h264(output_video_path)
if __name__ == '__main__':
    # Run shot-transition detection on one test video.
    model_path = "./transnetv2.onnx"
    test_video = './inference_test/input_videos/input_video.mp4'
    result_dir = './inference_test/output_videos_onnx'
    detector = TransNetV2_Onnx(model_path, window_size=100, overlap_size=20)
    detector.inference_video(test_video, result_dir, save_scene_txt=True)
更改输入视频路径和输出视频路径,即可生成与输入视频对应的输出视频文件,并在输出视频的左上角用红色的scene {num}标识当前帧属于第几个场景。
本文作者:StubbornHuang
版权声明:本文为站长原创文章,如果转载请注明原文链接!
原文标题:视频智能分镜/视频转场检测模型TransNetV2导出为onnx和基于onnxruntime的python推理类
原文链接:https://www.stubbornhuang.com/3128/
发布于:2025年03月03日 19:40:03
修改于:2025年03月05日 11:37:37
声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。
评论
52