Computer vision YOLO11 model

This commit is contained in:
Zuxin Dai
2025-01-24 10:04:59 -05:00
commit 2e6f7a8a1f
22 changed files with 476 additions and 0 deletions

80
Test_logic_track.py Normal file
View File

@@ -0,0 +1,80 @@
"""Track walnuts in a video with a custom YOLO model + BoT-SORT and annotate crack state.

Each tracked walnut keeps a *sticky* state: once a track ID has been seen as
"cracked", it stays "cracked" for the rest of the video even if a later frame
classifies the same walnut as "uncracked". Annotated frames are written to a
new output video.
"""
import os
import cv2
from ultralytics import YOLO

# Input and output video paths.
video_path = r'D:\AIM\pecan\GH014359.mp4'
video_path_out = r'D:\AIM\pecan\GH014359_out.mp4'

# Load the custom-trained detection model.
model = YOLO(r"D:\AIM\pecan\runs\detect\train2\weights\best.pt")

# Probe the source video once for the writer's fps/frame-size, then release it
# immediately: model.track() opens the file itself, so holding this capture
# open for the whole run would leak the handle.
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise IOError(f"Cannot open input video: {video_path}")
# Some containers report fps=0; fall back to 30 so the output stays playable.
fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
cap.release()

# VideoWriter for the annotated output.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(video_path_out, fourcc, fps, (width, height))

# Per-walnut sticky state, keyed by tracker ID: {ID: "cracked" | "uncracked"}.
walnut_states = {}

# Class-index -> label mapping for the trained model.
class_labels = {
    0: "circumferential",
    1: "cracked",
    2: "crushed",
    3: "longitudinal",
    4: "open",
    5: "uncracked"
}
# Only these classes get a persistent ID / tracked state.
id_tracked_classes = ["cracked", "uncracked"]

# Run BoT-SORT tracking over the whole video.
results = model.track(source=video_path, conf=0.5, tracker='botsort.yaml', show=False)
for result in results:
    frame = result.orig_img   # current frame (BGR)
    detections = result.boxes # detection boxes for this frame
    # Annotate every detection on the frame.
    for box in detections:
        x1, y1, x2, y2 = map(int, box.xyxy[0])              # box corners
        obj_id = int(box.id) if box.id is not None else -1  # tracker ID (-1 = untracked)
        class_id = int(box.cls)                             # class index
        label = class_labels.get(class_id, "unknown")
        if label in id_tracked_classes:
            # Sticky update: "cracked" is terminal and never reverts; any
            # other (or missing) state is overwritten with the new label.
            if walnut_states.get(obj_id) != "cracked":
                walnut_states[obj_id] = label
            display_text = f"ID {obj_id} | {walnut_states[obj_id]}"
        else:
            # Classes without ID tracking just show their raw label.
            display_text = label
        # Draw the box and its label.
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, display_text, (x1, y1 - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    # Write the annotated frame to the output video.
    out.write(frame)

# Release resources.
out.release()
print("视频处理完成,结果已保存。")