测试文件提交
This commit is contained in:
241
yolo_test.py
241
yolo_test.py
@@ -1,55 +1,216 @@
|
||||
import cv2
|
||||
from utils.get_image import get_image
|
||||
from utils.get_image import GetImage
|
||||
from ultralytics import YOLO
|
||||
from config import config_manager
|
||||
import os
|
||||
|
||||
model = YOLO(r"best0.pt").to('cuda')
|
||||
# 检查模型文件是否存在
|
||||
model_path = r"best0.pt"
|
||||
if not os.path.exists(model_path):
|
||||
print(f"❌ 模型文件不存在: {model_path}")
|
||||
exit(1)
|
||||
|
||||
def yolo_shibie(im_PIL, detections):
|
||||
results = model(im_PIL)
|
||||
result = results[0]
|
||||
# 加载YOLO模型
|
||||
try:
|
||||
model = YOLO(model_path).to('cuda')
|
||||
print(f"✅ 模型加载成功: {model_path}")
|
||||
except Exception as e:
|
||||
print(f"❌ 模型加载失败: {e}")
|
||||
exit(1)
|
||||
|
||||
# ✅ 获取绘制好框的图像
|
||||
frame_with_boxes = result.plot()
|
||||
def yolo_shibie(im_PIL, im_opencv_rgb, raw_frame_bgr, detections, model, show_original=True):
|
||||
"""
|
||||
YOLO识别函数
|
||||
:param im_PIL: PIL图像对象
|
||||
:param im_opencv_rgb: RGB格式的OpenCV图像(裁剪后)
|
||||
:param raw_frame_bgr: 原始BGR格式的OpenCV图像(未裁剪,与raw_frame.jpg一致)
|
||||
:param detections: 检测结果字典
|
||||
:param model: YOLO模型
|
||||
:param show_original: 是否同时显示原始帧
|
||||
:return: 更新后的detections字典,如果用户退出则返回None
|
||||
"""
|
||||
if im_PIL is None:
|
||||
return detections
|
||||
|
||||
try:
|
||||
results = model(im_PIL)
|
||||
result = results[0]
|
||||
|
||||
# ✅ 用 OpenCV 动态显示
|
||||
cv2.imshow("YOLO实时检测", frame_with_boxes)
|
||||
# ✅ 获取绘制好框的图像(RGB格式)
|
||||
frame_with_boxes_rgb = result.plot()
|
||||
|
||||
# ✅ 转换为BGR格式用于OpenCV显示
|
||||
frame_with_boxes_bgr = cv2.cvtColor(frame_with_boxes_rgb, cv2.COLOR_RGB2BGR)
|
||||
|
||||
# ESC 或 Q 键退出
|
||||
if cv2.waitKey(1) & 0xFF in [27, ord('q')]:
|
||||
return None
|
||||
# 显示画面
|
||||
if show_original and raw_frame_bgr is not None:
|
||||
# 同时显示原始帧和检测结果(并排显示)
|
||||
# 调整原始帧大小以匹配裁剪后的检测结果
|
||||
h, w = frame_with_boxes_bgr.shape[:2]
|
||||
# 裁剪原始帧(与get_frame的处理一致:30:30+720, 0:1280)
|
||||
raw_height, raw_width = raw_frame_bgr.shape[:2]
|
||||
crop_top = 30
|
||||
crop_bottom = min(crop_top + h, raw_height)
|
||||
crop_right = min(w, raw_width)
|
||||
raw_cropped = raw_frame_bgr[crop_top:crop_bottom, 0:crop_right]
|
||||
|
||||
# 如果尺寸不匹配,调整原始帧大小
|
||||
if raw_cropped.shape[:2] != (h, w):
|
||||
raw_cropped = cv2.resize(raw_cropped, (w, h))
|
||||
|
||||
# 并排显示:原始帧(左) | 检测结果(右)
|
||||
# 原始帧已经是BGR格式,检测结果也是BGR格式,可以直接拼接
|
||||
combined = cv2.hconcat([raw_cropped, frame_with_boxes_bgr])
|
||||
cv2.imshow("原始帧 (左, 与raw_frame.jpg一致) | YOLO检测结果 (右)", combined)
|
||||
else:
|
||||
# 只显示检测结果
|
||||
cv2.imshow("YOLO实时检测", frame_with_boxes_bgr)
|
||||
|
||||
# ✅ 提取检测信息
|
||||
for i in range(len(result.boxes.xyxy)):
|
||||
left, top, right, bottom = result.boxes.xyxy[i]
|
||||
cls_id = int(result.boxes.cls[i])
|
||||
label = result.names[cls_id]
|
||||
# ✅ 提取检测信息
|
||||
if result.boxes is not None and len(result.boxes.xyxy) > 0:
|
||||
for i in range(len(result.boxes.xyxy)):
|
||||
try:
|
||||
left = float(result.boxes.xyxy[i][0])
|
||||
top = float(result.boxes.xyxy[i][1])
|
||||
right = float(result.boxes.xyxy[i][2])
|
||||
bottom = float(result.boxes.xyxy[i][3])
|
||||
cls_id = int(result.boxes.cls[i])
|
||||
label = result.names[cls_id]
|
||||
|
||||
if label in ['center', 'next', 'npc1', 'npc2', 'npc3', 'npc4', 'boss', 'zhaozi']:
|
||||
player_x = int(left + (right - left) / 2) + 3
|
||||
player_y = int(top + (bottom - top) / 2) + 40
|
||||
detections[label] = [player_x, player_y]
|
||||
elif label in ['daojv', 'gw']:
|
||||
player_x = int(left + (right - left) / 2) + 3
|
||||
player_y = int(top + (bottom - top) / 2) + 40
|
||||
detections[label].append([player_x, player_y])
|
||||
if label in ['center', 'next', 'npc1', 'npc2', 'npc3', 'npc4', 'boss', 'zhaozi']:
|
||||
player_x = int(left + (right - left) / 2) + 3
|
||||
player_y = int(top + (bottom - top) / 2) + 40
|
||||
detections[label] = [player_x, player_y]
|
||||
elif label in ['daojv', 'gw']:
|
||||
player_x = int(left + (right - left) / 2) + 3
|
||||
player_y = int(top + (bottom - top) / 2) + 40
|
||||
# 确保列表存在
|
||||
if label not in detections:
|
||||
detections[label] = []
|
||||
detections[label].append([player_x, player_y])
|
||||
except Exception as e:
|
||||
print(f"⚠️ 处理检测框时出错: {e}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ YOLO检测出错: {e}")
|
||||
|
||||
return detections
|
||||
|
||||
|
||||
while True:
|
||||
detections = {
|
||||
'center': None, 'next': None,
|
||||
'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None,
|
||||
'boss': None, 'zhaozi': None,
|
||||
'daojv': [], 'gw': []
|
||||
}
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("="*60)
|
||||
print("YOLO实时检测测试")
|
||||
print("="*60)
|
||||
|
||||
# 从配置加载采集卡设置
|
||||
active_group = config_manager.get_active_group()
|
||||
|
||||
if active_group is None:
|
||||
print("⚠️ 没有活动的配置组,使用默认设置")
|
||||
print("提示: 可以运行 python gui_config.py 设置配置")
|
||||
cam_index = 0
|
||||
width = 1920
|
||||
height = 1080
|
||||
else:
|
||||
print(f"📋 使用配置组: {active_group['name']}")
|
||||
cam_index = active_group['camera_index']
|
||||
width = active_group['camera_width']
|
||||
height = active_group['camera_height']
|
||||
|
||||
print(f" 采集卡索引: {cam_index}")
|
||||
print(f" 分辨率: {width}x{height}")
|
||||
print()
|
||||
|
||||
# 初始化采集卡
|
||||
print("🔧 正在初始化采集卡...")
|
||||
get_image = GetImage(
|
||||
cam_index=cam_index,
|
||||
width=width,
|
||||
height=height
|
||||
)
|
||||
|
||||
if get_image.cap is None:
|
||||
print("❌ 采集卡初始化失败")
|
||||
print("请检查:")
|
||||
print("1. 采集卡是否正确连接")
|
||||
print("2. 采集卡索引是否正确")
|
||||
print("3. 采集卡驱动是否安装")
|
||||
return
|
||||
|
||||
print("✅ 采集卡初始化成功")
|
||||
print("按 'q' 或 ESC 键退出测试")
|
||||
print()
|
||||
|
||||
try:
|
||||
frame_count = 0
|
||||
show_original = True # 默认同时显示原始帧和检测结果
|
||||
|
||||
print("提示: 按 'o' 键切换是否显示原始帧对比")
|
||||
print()
|
||||
|
||||
while True:
|
||||
# 获取帧
|
||||
frame_data = get_image.get_frame()
|
||||
|
||||
if frame_data is None:
|
||||
print("⚠️ 无法获取帧,跳过...")
|
||||
continue
|
||||
|
||||
# frame_data 是 [im_opencv_rgb, im_PIL] 格式
|
||||
# im_opencv_rgb 已经是RGB格式(经过BGR2RGB转换)
|
||||
im_opencv_rgb, im_PIL = frame_data
|
||||
|
||||
if im_PIL is None:
|
||||
print("⚠️ PIL图像为空,跳过...")
|
||||
continue
|
||||
|
||||
# 获取原始BGR帧(与test_capture_card.py保存的raw_frame.jpg一致)
|
||||
raw_frame_bgr = None
|
||||
if get_image.cap is not None and get_image.frame is not None:
|
||||
raw_frame_bgr = get_image.frame.copy() # 原始BGR格式,未裁剪
|
||||
|
||||
# 初始化检测结果字典
|
||||
detections = {
|
||||
'center': None, 'next': None,
|
||||
'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None,
|
||||
'boss': None, 'zhaozi': None,
|
||||
'daojv': [], 'gw': []
|
||||
}
|
||||
|
||||
# 执行YOLO检测
|
||||
detections = yolo_shibie(im_PIL, im_opencv_rgb, raw_frame_bgr, detections, model, show_original)
|
||||
|
||||
# 检查按键
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
if key in [27, ord('q'), ord('Q')]:
|
||||
print("\n用户退出")
|
||||
break
|
||||
elif key == ord('o') or key == ord('O'):
|
||||
show_original = not show_original
|
||||
print(f"切换显示模式: {'原始帧对比' if show_original else '仅检测结果'}")
|
||||
|
||||
frame_count += 1
|
||||
if frame_count % 30 == 0: # 每30帧打印一次
|
||||
print(f"📊 已处理 {frame_count} 帧")
|
||||
# 打印有检测到的目标
|
||||
detected_items = {k: v for k, v in detections.items() if v is not None and v != []}
|
||||
if detected_items:
|
||||
print(f" 检测到: {detected_items}")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n用户中断测试")
|
||||
except Exception as e:
|
||||
print(f"\n❌ 测试过程中发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
# 清理资源
|
||||
get_image.release()
|
||||
cv2.destroyAllWindows()
|
||||
print("🔚 测试结束")
|
||||
|
||||
im_opencv = get_image.get_frame() # [RGB, PIL]
|
||||
detections = yolo_shibie(im_opencv[1], detections)
|
||||
|
||||
if detections is None: # 用户退出
|
||||
break
|
||||
|
||||
print(detections)
|
||||
|
||||
cv2.destroyAllWindows()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user