Files
huojv/yolo_test.py
2025-11-04 11:32:16 +08:00

217 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys

import cv2
import torch
from ultralytics import YOLO

from config import config_manager
from utils.get_image import GetImage
# Path to the YOLO weights file (resolved relative to the current working directory).
model_path = r"best0.pt"

# Abort early with a clear message if the weights file is missing.
if not os.path.exists(model_path):
    print(f"❌ 模型文件不存在: {model_path}")
    # sys.exit rather than the interactive exit() helper: exit() is injected by
    # the `site` module and is absent under `python -S` or in frozen builds.
    sys.exit(1)

# Load the YOLO model once at import time; main() passes this global to yolo_shibie().
try:
    # Prefer the GPU, but fall back to the CPU instead of aborting on machines
    # without a usable CUDA device.
    _device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = YOLO(model_path).to(_device)
    print(f"✅ 模型加载成功: {model_path}")
except Exception as e:
    print(f"❌ 模型加载失败: {e}")
    sys.exit(1)
# Labels that mark a single point per frame; a later box with the same label
# overwrites an earlier one.
_POINT_LABELS = frozenset({'center', 'next', 'npc1', 'npc2', 'npc3', 'npc4', 'boss', 'zhaozi'})
# Labels that can occur several times per frame; every box is appended.
_MULTI_LABELS = frozenset({'daojv', 'gw'})
# Fixed pixel offsets added to every box centre. Their origin is not visible
# here -- TODO(review): confirm against the capture/cropping code in GetImage.
_CENTER_OFFSET_X = 3
_CENTER_OFFSET_Y = 40
# First raw-frame row kept when lining the raw view up with the detection view;
# mirrors the crop get_frame is said to apply (rows 30:30+720, cols 0:1280).
_RAW_CROP_TOP = 30


def _show_result(result, raw_frame_bgr, show_original):
    """Display the annotated detection frame, optionally beside the raw frame.

    Errors are reported and swallowed so that a display problem (headless
    OpenCV build, raw frame too small to crop, mismatched channel counts, ...)
    can never stop detections from being extracted.
    """
    try:
        # NOTE(review): this keeps the original assumption that plot() returns
        # RGB. Ultralytics' own examples pass plot() straight to cv2.imshow,
        # which suggests it is already BGR for PIL input; if so this swap shows
        # red/blue inverted -- confirm with a test frame.
        frame_with_boxes_bgr = cv2.cvtColor(result.plot(), cv2.COLOR_RGB2BGR)
        if show_original and raw_frame_bgr is not None:
            h, w = frame_with_boxes_bgr.shape[:2]
            raw_height, raw_width = raw_frame_bgr.shape[:2]
            raw_cropped = raw_frame_bgr[
                _RAW_CROP_TOP:min(_RAW_CROP_TOP + h, raw_height),
                0:min(w, raw_width),
            ]
            # A raw frame shorter than the crop offset yields an empty slice,
            # which cv2.resize rejects; fall back to the detection-only view.
            if raw_cropped.size > 0:
                if raw_cropped.shape[:2] != (h, w):
                    raw_cropped = cv2.resize(raw_cropped, (w, h))
                # Both halves are BGR, so they can be concatenated directly.
                combined = cv2.hconcat([raw_cropped, frame_with_boxes_bgr])
                cv2.imshow("原始帧 (左, 与raw_frame.jpg一致) | YOLO检测结果 (右)", combined)
                return
        cv2.imshow("YOLO实时检测", frame_with_boxes_bgr)
    except Exception as e:
        print(f"⚠️ 显示检测结果出错: {e}")


def _extract_detections(result, detections):
    """Add the centre point of each recognised box in *result* to *detections*."""
    boxes = result.boxes
    if boxes is None:
        return
    # Convert to Python lists once: indexing a (possibly CUDA) tensor element
    # by element forces a separate device-to-host copy per coordinate.
    coords = boxes.xyxy.tolist()
    class_ids = boxes.cls.tolist()
    for box, cls_id in zip(coords, class_ids):
        try:
            left, top, right, bottom = box
            label = result.names[int(cls_id)]
            if label not in _POINT_LABELS and label not in _MULTI_LABELS:
                continue
            point = [
                int(left + (right - left) / 2) + _CENTER_OFFSET_X,
                int(top + (bottom - top) / 2) + _CENTER_OFFSET_Y,
            ]
            if label in _POINT_LABELS:
                detections[label] = point
            else:
                # Create the list on first sight in case the caller's dict lacks it.
                detections.setdefault(label, []).append(point)
        except Exception as e:
            # Best effort: one malformed box must not discard the remaining ones.
            print(f"⚠️ 处理检测框时出错: {e}")


def yolo_shibie(im_PIL, im_opencv_rgb, raw_frame_bgr, detections, model, show_original=True):
    """Run YOLO on one frame, display the annotated result and collect detections.

    :param im_PIL: PIL image fed to the model; when None the call is a no-op.
    :param im_opencv_rgb: cropped RGB OpenCV image. Currently unused; kept so
        existing callers keep working.
    :param raw_frame_bgr: original, uncropped BGR frame (same as raw_frame.jpg),
        shown on the left in comparison mode; may be None.
    :param detections: dict updated in place. Single-point labels (center, next,
        npc1-4, boss, zhaozi) are set to ``[x, y]``; multi labels (daojv, gw)
        get ``[x, y]`` appended to a list.
    :param model: YOLO model; ``model(im_PIL)[0]`` is used as the frame result.
    :param show_original: also show the raw frame side by side with the result.
    :return: the same *detections* dict. It is returned unchanged when
        ``im_PIL`` is None or inference fails; this function never returns None.
    """
    if im_PIL is None:
        return detections
    try:
        result = model(im_PIL)[0]
        # Display guards its own errors so it cannot block extraction below.
        _show_result(result, raw_frame_bgr, show_original)
        _extract_detections(result, detections)
    except Exception as e:
        print(f"⚠️ YOLO检测出错: {e}")
    return detections
# Key codes as returned by ``cv2.waitKey(1) & 0xFF``.
_QUIT_KEYS = (27, ord('q'), ord('Q'))  # ESC, q, Q
_TOGGLE_KEYS = (ord('o'), ord('O'))


def _empty_detections():
    """Return a fresh per-frame detections dict pre-populated with every known label."""
    return {
        'center': None, 'next': None,
        'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None,
        'boss': None, 'zhaozi': None,
        'daojv': [], 'gw': [],
    }


def _process_frame(get_image, show_original):
    """Grab one frame from *get_image*, run detection and return the detections.

    Returns None (after printing why) when no usable frame was available.
    """
    frame_data = get_image.get_frame()
    if frame_data is None:
        print("⚠️ 无法获取帧,跳过...")
        return None
    # frame_data is [im_opencv_rgb, im_PIL]; im_opencv_rgb has already been
    # converted from BGR to RGB by get_frame.
    im_opencv_rgb, im_PIL = frame_data
    if im_PIL is None:
        print("⚠️ PIL图像为空跳过...")
        return None
    # Uncropped BGR frame (same content test_capture_card.py saves as raw_frame.jpg).
    raw_frame_bgr = None
    if get_image.cap is not None and get_image.frame is not None:
        raw_frame_bgr = get_image.frame.copy()
    return yolo_shibie(im_PIL, im_opencv_rgb, raw_frame_bgr, _empty_detections(), model, show_original)


def main():
    """Run the live YOLO detection test against the configured capture card.

    Capture settings come from the active config group, falling back to
    camera index 0 at 1920x1080 when none is active. The loop then grabs a
    frame, runs yolo_shibie on it and handles the keyboard: q/Q/ESC quits,
    o/O toggles the raw-frame comparison view. Every 30 processed frames a
    progress line and the non-empty detections are printed. The capture card
    and all OpenCV windows are released on exit.
    """
    print("="*60)
    print("YOLO实时检测测试")
    print("="*60)
    # Load capture-card settings from the active config group.
    active_group = config_manager.get_active_group()
    if active_group is None:
        print("⚠️ 没有活动的配置组,使用默认设置")
        print("提示: 可以运行 python gui_config.py 设置配置")
        cam_index = 0
        width = 1920
        height = 1080
    else:
        print(f"📋 使用配置组: {active_group['name']}")
        cam_index = active_group['camera_index']
        width = active_group['camera_width']
        height = active_group['camera_height']
    print(f" 采集卡索引: {cam_index}")
    print(f" 分辨率: {width}x{height}")
    print()
    # Initialise the capture card.
    print("🔧 正在初始化采集卡...")
    get_image = GetImage(
        cam_index=cam_index,
        width=width,
        height=height
    )
    if get_image.cap is None:
        print("❌ 采集卡初始化失败")
        print("请检查:")
        print("1. 采集卡是否正确连接")
        print("2. 采集卡索引是否正确")
        print("3. 采集卡驱动是否安装")
        return
    print("✅ 采集卡初始化成功")
    print("'q' 或 ESC 键退出测试")
    print()
    try:
        frame_count = 0
        show_original = True  # Show raw frame and detections side by side by default.
        print("提示: 按 'o' 键切换是否显示原始帧对比")
        print()
        while True:
            detections = _process_frame(get_image, show_original)
            # Poll the keyboard on every iteration, including skipped frames:
            # this pumps the HighGUI event loop so q/ESC keep working and the
            # loop does not busy-spin while the capture keeps failing.
            key = cv2.waitKey(1) & 0xFF
            if key in _QUIT_KEYS:
                print("\n用户退出")
                break
            if key in _TOGGLE_KEYS:
                show_original = not show_original
                print(f"切换显示模式: {'原始帧对比' if show_original else '仅检测结果'}")
            if detections is None:
                continue  # No frame was processed this iteration.
            frame_count += 1
            if frame_count % 30 == 0:  # Report progress every 30 processed frames.
                print(f"📊 已处理 {frame_count}")
                # Print only the labels that were actually detected.
                detected_items = {k: v for k, v in detections.items() if v is not None and v != []}
                if detected_items:
                    print(f" 检测到: {detected_items}")
    except KeyboardInterrupt:
        print("\n\n用户中断测试")
    except Exception as e:
        print(f"\n❌ 测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release the capture card and close all windows.
        get_image.release()
        cv2.destroyAllWindows()
        print("🔚 测试结束")


if __name__ == "__main__":
    main()