217 lines
7.9 KiB
Python
217 lines
7.9 KiB
Python
import cv2
|
||
from utils.get_image import GetImage
|
||
from ultralytics import YOLO
|
||
from config import config_manager
|
||
import os
|
||
|
||
# 检查模型文件是否存在
|
||
model_path = r"best0.pt"
|
||
if not os.path.exists(model_path):
|
||
print(f"❌ 模型文件不存在: {model_path}")
|
||
exit(1)
|
||
|
||
# 加载YOLO模型
|
||
try:
|
||
model = YOLO(model_path).to('cuda')
|
||
print(f"✅ 模型加载成功: {model_path}")
|
||
except Exception as e:
|
||
print(f"❌ 模型加载失败: {e}")
|
||
exit(1)
|
||
|
||
def yolo_shibie(im_PIL, im_opencv_rgb, raw_frame_bgr, detections, model, show_original=True):
|
||
"""
|
||
YOLO识别函数
|
||
:param im_PIL: PIL图像对象
|
||
:param im_opencv_rgb: RGB格式的OpenCV图像(裁剪后)
|
||
:param raw_frame_bgr: 原始BGR格式的OpenCV图像(未裁剪,与raw_frame.jpg一致)
|
||
:param detections: 检测结果字典
|
||
:param model: YOLO模型
|
||
:param show_original: 是否同时显示原始帧
|
||
:return: 更新后的detections字典,如果用户退出则返回None
|
||
"""
|
||
if im_PIL is None:
|
||
return detections
|
||
|
||
try:
|
||
results = model(im_PIL)
|
||
result = results[0]
|
||
|
||
# ✅ 获取绘制好框的图像(RGB格式)
|
||
frame_with_boxes_rgb = result.plot()
|
||
|
||
# ✅ 转换为BGR格式用于OpenCV显示
|
||
frame_with_boxes_bgr = cv2.cvtColor(frame_with_boxes_rgb, cv2.COLOR_RGB2BGR)
|
||
|
||
# 显示画面
|
||
if show_original and raw_frame_bgr is not None:
|
||
# 同时显示原始帧和检测结果(并排显示)
|
||
# 调整原始帧大小以匹配裁剪后的检测结果
|
||
h, w = frame_with_boxes_bgr.shape[:2]
|
||
# 裁剪原始帧(与get_frame的处理一致:30:30+720, 0:1280)
|
||
raw_height, raw_width = raw_frame_bgr.shape[:2]
|
||
crop_top = 30
|
||
crop_bottom = min(crop_top + h, raw_height)
|
||
crop_right = min(w, raw_width)
|
||
raw_cropped = raw_frame_bgr[crop_top:crop_bottom, 0:crop_right]
|
||
|
||
# 如果尺寸不匹配,调整原始帧大小
|
||
if raw_cropped.shape[:2] != (h, w):
|
||
raw_cropped = cv2.resize(raw_cropped, (w, h))
|
||
|
||
# 并排显示:原始帧(左) | 检测结果(右)
|
||
# 原始帧已经是BGR格式,检测结果也是BGR格式,可以直接拼接
|
||
combined = cv2.hconcat([raw_cropped, frame_with_boxes_bgr])
|
||
cv2.imshow("原始帧 (左, 与raw_frame.jpg一致) | YOLO检测结果 (右)", combined)
|
||
else:
|
||
# 只显示检测结果
|
||
cv2.imshow("YOLO实时检测", frame_with_boxes_bgr)
|
||
|
||
# ✅ 提取检测信息
|
||
if result.boxes is not None and len(result.boxes.xyxy) > 0:
|
||
for i in range(len(result.boxes.xyxy)):
|
||
try:
|
||
left = float(result.boxes.xyxy[i][0])
|
||
top = float(result.boxes.xyxy[i][1])
|
||
right = float(result.boxes.xyxy[i][2])
|
||
bottom = float(result.boxes.xyxy[i][3])
|
||
cls_id = int(result.boxes.cls[i])
|
||
label = result.names[cls_id]
|
||
|
||
if label in ['center', 'next', 'npc1', 'npc2', 'npc3', 'npc4', 'boss', 'zhaozi']:
|
||
player_x = int(left + (right - left) / 2) + 3
|
||
player_y = int(top + (bottom - top) / 2) + 40
|
||
detections[label] = [player_x, player_y]
|
||
elif label in ['daojv', 'gw']:
|
||
player_x = int(left + (right - left) / 2) + 3
|
||
player_y = int(top + (bottom - top) / 2) + 40
|
||
# 确保列表存在
|
||
if label not in detections:
|
||
detections[label] = []
|
||
detections[label].append([player_x, player_y])
|
||
except Exception as e:
|
||
print(f"⚠️ 处理检测框时出错: {e}")
|
||
continue
|
||
|
||
except Exception as e:
|
||
print(f"⚠️ YOLO检测出错: {e}")
|
||
|
||
return detections
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("="*60)
|
||
print("YOLO实时检测测试")
|
||
print("="*60)
|
||
|
||
# 从配置加载采集卡设置
|
||
active_group = config_manager.get_active_group()
|
||
|
||
if active_group is None:
|
||
print("⚠️ 没有活动的配置组,使用默认设置")
|
||
print("提示: 可以运行 python gui_config.py 设置配置")
|
||
cam_index = 0
|
||
width = 1920
|
||
height = 1080
|
||
else:
|
||
print(f"📋 使用配置组: {active_group['name']}")
|
||
cam_index = active_group['camera_index']
|
||
width = active_group['camera_width']
|
||
height = active_group['camera_height']
|
||
|
||
print(f" 采集卡索引: {cam_index}")
|
||
print(f" 分辨率: {width}x{height}")
|
||
print()
|
||
|
||
# 初始化采集卡
|
||
print("🔧 正在初始化采集卡...")
|
||
get_image = GetImage(
|
||
cam_index=cam_index,
|
||
width=width,
|
||
height=height
|
||
)
|
||
|
||
if get_image.cap is None:
|
||
print("❌ 采集卡初始化失败")
|
||
print("请检查:")
|
||
print("1. 采集卡是否正确连接")
|
||
print("2. 采集卡索引是否正确")
|
||
print("3. 采集卡驱动是否安装")
|
||
return
|
||
|
||
print("✅ 采集卡初始化成功")
|
||
print("按 'q' 或 ESC 键退出测试")
|
||
print()
|
||
|
||
try:
|
||
frame_count = 0
|
||
show_original = True # 默认同时显示原始帧和检测结果
|
||
|
||
print("提示: 按 'o' 键切换是否显示原始帧对比")
|
||
print()
|
||
|
||
while True:
|
||
# 获取帧
|
||
frame_data = get_image.get_frame()
|
||
|
||
if frame_data is None:
|
||
print("⚠️ 无法获取帧,跳过...")
|
||
continue
|
||
|
||
# frame_data 是 [im_opencv_rgb, im_PIL] 格式
|
||
# im_opencv_rgb 已经是RGB格式(经过BGR2RGB转换)
|
||
im_opencv_rgb, im_PIL = frame_data
|
||
|
||
if im_PIL is None:
|
||
print("⚠️ PIL图像为空,跳过...")
|
||
continue
|
||
|
||
# 获取原始BGR帧(与test_capture_card.py保存的raw_frame.jpg一致)
|
||
raw_frame_bgr = None
|
||
if get_image.cap is not None and get_image.frame is not None:
|
||
raw_frame_bgr = get_image.frame.copy() # 原始BGR格式,未裁剪
|
||
|
||
# 初始化检测结果字典
|
||
detections = {
|
||
'center': None, 'next': None,
|
||
'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None,
|
||
'boss': None, 'zhaozi': None,
|
||
'daojv': [], 'gw': []
|
||
}
|
||
|
||
# 执行YOLO检测
|
||
detections = yolo_shibie(im_PIL, im_opencv_rgb, raw_frame_bgr, detections, model, show_original)
|
||
|
||
# 检查按键
|
||
key = cv2.waitKey(1) & 0xFF
|
||
if key in [27, ord('q'), ord('Q')]:
|
||
print("\n用户退出")
|
||
break
|
||
elif key == ord('o') or key == ord('O'):
|
||
show_original = not show_original
|
||
print(f"切换显示模式: {'原始帧对比' if show_original else '仅检测结果'}")
|
||
|
||
frame_count += 1
|
||
if frame_count % 30 == 0: # 每30帧打印一次
|
||
print(f"📊 已处理 {frame_count} 帧")
|
||
# 打印有检测到的目标
|
||
detected_items = {k: v for k, v in detections.items() if v is not None and v != []}
|
||
if detected_items:
|
||
print(f" 检测到: {detected_items}")
|
||
|
||
except KeyboardInterrupt:
|
||
print("\n\n用户中断测试")
|
||
except Exception as e:
|
||
print(f"\n❌ 测试过程中发生错误: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
finally:
|
||
# 清理资源
|
||
get_image.release()
|
||
cv2.destroyAllWindows()
|
||
print("🔚 测试结束")
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|