Computer vision YOLO11 model
This commit is contained in:
80
Test_logic_track.py
Normal file
80
Test_logic_track.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import os
|
||||
import cv2
|
||||
from ultralytics import YOLO
|
||||
|
||||
# 输入和输出视频路径
|
||||
# video_path = r'D:\AIM\pecan\OneDrive_2_2024-11-7\G5 Flex 01 8-5-2024, 1.20.12pm EDT - 8-5-2024, 1.23.22pm EDT.mp4'
|
||||
# video_path_out = r'D:\AIM\pecan\G5 Flex 01 8-5-2024_out.mp4'
|
||||
video_path = r'D:\AIM\pecan\GH014359.mp4'
|
||||
video_path_out = r'D:\AIM\pecan\GH014359_out.mp4'
|
||||
|
||||
# 加载 YOLO 模型
|
||||
model = YOLO(r"D:\AIM\pecan\runs\detect\train2\weights\best.pt") # 加载自定义模型
|
||||
|
||||
# 初始化 VideoWriter 用于保存输出视频
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
||||
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
||||
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
||||
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||
out = cv2.VideoWriter(video_path_out, fourcc, fps, (width, height))
|
||||
|
||||
# 字典,用于跟踪每个核桃的状态
|
||||
walnut_states = {} # 格式: {ID: "状态"}
|
||||
|
||||
# 定义类别标签
|
||||
class_labels = {
|
||||
0: "circumferential",
|
||||
1: "cracked",
|
||||
2: "crushed",
|
||||
3: "longitudinal",
|
||||
4: "open",
|
||||
5: "uncracked"
|
||||
}
|
||||
|
||||
# 需要分配 ID 的类别
|
||||
id_tracked_classes = ["cracked", "uncracked"]
|
||||
|
||||
# 使用 BoT-SORT 进行目标跟踪
|
||||
results = model.track(source=video_path, conf=0.5, tracker='botsort.yaml', show=False)
|
||||
|
||||
for result in results:
|
||||
frame = result.orig_img # 当前帧
|
||||
detections = result.boxes # 检测框信息
|
||||
|
||||
# 处理每个检测框
|
||||
for box in detections:
|
||||
x1, y1, x2, y2 = map(int, box.xyxy[0]) # 检测框坐标
|
||||
obj_id = int(box.id) if box.id is not None else -1 # 跟踪目标ID
|
||||
class_id = int(box.cls) # 类别ID
|
||||
score = box.conf # 置信度
|
||||
|
||||
# 获取检测框对应的类别标签
|
||||
label = class_labels.get(class_id, "unknown")
|
||||
|
||||
# 仅对需要分配ID的类别更新核桃状态
|
||||
if label in id_tracked_classes:
|
||||
if obj_id not in walnut_states:
|
||||
walnut_states[obj_id] = label
|
||||
else:
|
||||
# 一旦检测到“cracked”,状态保持为“cracked”
|
||||
if walnut_states[obj_id] != "cracked":
|
||||
walnut_states[obj_id] = label
|
||||
|
||||
display_text = f"ID {obj_id} | {walnut_states[obj_id]}"
|
||||
else:
|
||||
# 非分配ID的类别,仅显示类别标签
|
||||
display_text = label
|
||||
|
||||
# 绘制检测框和标签
|
||||
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
||||
cv2.putText(frame, display_text, (x1, y1 - 10),
|
||||
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
|
||||
|
||||
# 将处理好的帧写入输出视频
|
||||
out.write(frame)
|
||||
|
||||
# 释放资源
|
||||
cap.release()
|
||||
out.release()
|
||||
print("视频处理完成,结果已保存。")
|
||||
Reference in New Issue
Block a user