Files
huojv/yolotest2.py
2025-11-04 14:47:28 +08:00

227 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
YOLO detection test script extracted from main.py.

Uses the same detection logic as main.py.
"""
import cv2
from utils.get_image import GetImage
from ultralytics import YOLO
from config import config_manager
import os
# Model weight files (relative paths, resolved against the current working
# directory).
model_path = r"best.pt"
model0_path = r"best0.pt"


def _require_file(path):
    """Abort the script with exit status 1 if *path* does not exist.

    Uses ``SystemExit`` rather than the ``exit()`` helper injected by the
    ``site`` module, which is intended for the interactive shell only.
    """
    if not os.path.exists(path):
        print(f"❌ 模型文件不存在: {path}")
        raise SystemExit(1)


# Fail fast before trying to load anything.
_require_file(model_path)
_require_file(model0_path)

# Load both YOLO models onto the CUDA device, keeping the same setup as main.py.
try:
    model = YOLO(model_path).to('cuda')
    model0 = YOLO(model0_path).to('cuda')
except Exception as e:
    print(f"❌ 模型加载失败: {e}")
    raise SystemExit(1) from e
print(f"✅ 模型加载成功: {model_path}")
print(f"✅ 模型加载成功: {model0_path}")
# Labels stored as a single [x, y] point (a later box overwrites an earlier one).
_POINT_LABELS = frozenset({'center', 'next', 'boss', 'zhaozi'})
# Labels for which every detected box is appended to a list.
_LIST_LABELS = frozenset({'daojv', 'gw'})
# NPC labels: the stored point uses the box's bottom edge instead of its centre.
_NPC_LABELS = frozenset({'npc1', 'npc2', 'npc3', 'npc4'})


def yolo_shibie(im_PIL, detections, model):
    """Run YOLO on one image and record target points in *detections*.

    Same detection logic as main.py. For every detected box:

    * ``center``/``next``/``boss``/``zhaozi``: store the box centre shifted
      down by 30 px as ``[x, y]``; a later box with the same label overwrites
      an earlier one.
    * ``daojv``/``gw``: append that shifted centre to a list, keeping every box.
    * ``npc1``..``npc4``: store the horizontal centre and the bottom edge
      shifted down by 30 px.

    Boxes with any other label are ignored.

    :param im_PIL: image (PIL image in this script), passed straight to ``model``
    :param detections: dict updated in place; list-valued keys are created on
        demand if the caller did not pre-populate them
    :param model: callable (e.g. an ultralytics ``YOLO`` model) returning an
        iterable of results exposing ``boxes.xyxy``, ``boxes.cls`` and ``names``
    :return: the same *detections* dict, updated
    """
    results = model(im_PIL)  # object detection
    for result in results:
        boxes = result.boxes
        for (left, top, right, bottom), cls_value in zip(boxes.xyxy, boxes.cls):
            label = result.names[int(cls_value.item())]
            center_x = int(left + (right - left) / 2)
            if label in _POINT_LABELS:
                detections[label] = [center_x, int(top + (bottom - top) / 2) + 30]
            elif label in _LIST_LABELS:
                # setdefault avoids a KeyError when the caller omitted the list.
                detections.setdefault(label, []).append(
                    [center_x, int(top + (bottom - top) / 2) + 30]
                )
            elif label in _NPC_LABELS:
                detections[label] = [center_x, int(bottom) + 30]
    return detections
# Key codes (low byte of cv2.waitKey) that end the preview: ESC, 'q', 'Q'.
_QUIT_KEYS = (27, ord('q'), ord('Q'))
# Detection keys whose value is a list of points; others hold one [x, y] or None.
_LIST_KEYS = ('daojv', 'gw')


def _select_camera_settings():
    """Read capture-card settings from the active config group.

    With no active group, falls back to camera 0 at 1920x1080 and ``model``;
    otherwise uses the group's camera settings and ``model0`` (the town model).

    :return: tuple ``(cam_index, width, height, use_model)``
    """
    active_group = config_manager.get_active_group()
    if active_group is None:
        print("⚠️ 没有活动的配置组,使用默认设置")
        print("提示: 可以运行 python gui_config.py 设置配置")
        return 0, 1920, 1080, model
    print(f"📋 使用配置组: {active_group['name']}")
    cam_index = active_group['camera_index']
    width = active_group['camera_width']
    height = active_group['camera_height']
    print(" 使用模型: model0 (best0.pt) - 用于城镇识别")
    return cam_index, width, height, model0


def _new_detections():
    """Return an empty detection dict with every key yolo_shibie may fill."""
    return {
        'center': None, 'next': None,
        'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None,
        'boss': None, 'zhaozi': None,
        'daojv': [], 'gw': [],
    }


def _summarize_detections(detections):
    """Return ``"key: value"`` strings for each non-empty detection entry.

    List-valued keys are summarised by their element count.
    """
    items = []
    for key, value in detections.items():
        if value is None or value == []:
            continue
        if key in _LIST_KEYS:
            items.append(f"{key}: {len(value)}")
        else:
            items.append(f"{key}: {value}")
    return items


def _render_frame(im_opencv_rgb, results):
    """Return a BGR frame for cv2.imshow, with detection boxes drawn if possible.

    ultralytics ``Results.plot()`` returns a BGR array, which cv2.imshow
    displays directly, so it is not channel-converted here. If plotting
    fails, the raw frame is shown instead. ``im_opencv_rgb`` is assumed to
    be RGB, as its name suggests.
    """
    try:
        return results[0].plot()
    except Exception as e:
        print(f"⚠️ 绘制检测框失败: {e}")
        return cv2.cvtColor(im_opencv_rgb, cv2.COLOR_RGB2BGR)


def _draw_overlay(frame_bgr, model_name, detections):
    """Draw the model name and a one-line detection summary onto *frame_bgr*."""
    cv2.putText(frame_bgr, f"Model: {model_name}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
    items = _summarize_detections(detections)
    if items:
        text = f"Detected: {', '.join(items[:5])}"  # show at most 5 entries
        if len(items) > 5:
            text += f" ... (+{len(items)-5})"
        cv2.putText(frame_bgr, text,
                    (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)


def _wait_key(delay_ms):
    """Pump the OpenCV GUI event loop for *delay_ms* ms; return the key's low byte."""
    return cv2.waitKey(delay_ms) & 0xFF


def main():
    """Run a live YOLO detection preview on the configured capture card.

    Hotkeys: 'q'/'Q'/ESC quit, 'm'/'M' toggle model/model0, 'd'/'D' toggle
    the on-frame detection overlay. The capture device is always released
    on exit.
    """
    print("=" * 60)
    print("YOLO识别测试main.py逻辑")
    print("=" * 60)

    cam_index, width, height, use_model = _select_camera_settings()
    print(f" 采集卡索引: {cam_index}")
    print(f" 分辨率: {width}x{height}")
    print()

    print("🔧 正在初始化采集卡...")
    get_image = GetImage(cam_index=cam_index, width=width, height=height)
    if get_image.cap is None:
        print("❌ 采集卡初始化失败")
        print("请检查:")
        print("1. 采集卡是否正确连接")
        print("2. 采集卡索引是否正确")
        print("3. 采集卡驱动是否安装")
        return
    print("✅ 采集卡初始化成功")
    print("\n快捷键:")
    print(" 'q' 或 ESC - 退出")
    print(" 'm' - 切换模型 (model/model0)")
    print(" 'd' - 显示/隐藏检测信息")
    print()

    try:
        frame_count = 0
        show_detections = True
        current_model = use_model
        current_model_name = "model0" if use_model is model0 else "model"
        while True:
            frame_data = get_image.get_frame()
            if frame_data is None:
                print("⚠️ 无法获取帧,跳过...")
                # Brief wait instead of a hot spin; also lets an open window
                # still react to the quit key while frames are failing.
                if _wait_key(10) in _QUIT_KEYS:
                    print("\n用户退出")
                    break
                continue

            # frame_data is [im_opencv_rgb, im_PIL].
            im_opencv_rgb, im_PIL = frame_data
            if im_PIL is None:
                print("⚠️ PIL图像为空跳过...")
                if _wait_key(10) in _QUIT_KEYS:
                    print("\n用户退出")
                    break
                continue

            # Run inference once per frame. yolo_shibie gets a stand-in
            # callable returning these results, so the detection dict and
            # the drawn boxes come from the same forward pass.
            results = current_model(im_PIL)
            detections = yolo_shibie(im_PIL, _new_detections(),
                                     lambda _img, _res=results: _res)

            frame_bgr = _render_frame(im_opencv_rgb, results)
            if show_detections:
                _draw_overlay(frame_bgr, current_model_name, detections)
            cv2.imshow("YOLO Detection (main.py logic)", frame_bgr)

            key = _wait_key(1)
            if key in _QUIT_KEYS:
                print("\n用户退出")
                break
            if key in (ord('m'), ord('M')):
                if current_model is model:
                    current_model, current_model_name = model0, "model0"
                else:
                    current_model, current_model_name = model, "model"
                print(f"切换模型: {current_model_name}")
            elif key in (ord('d'), ord('D')):
                show_detections = not show_detections
                print(f"显示检测信息: {'开启' if show_detections else '关闭'}")

            frame_count += 1
            if frame_count % 30 == 0:  # periodic console report
                print(f"📊 已处理 {frame_count} 帧 (模型: {current_model_name})")
                found = {k: v for k, v in detections.items()
                         if v is not None and v != []}
                if found:
                    print(f" 检测到: {found}")
    except KeyboardInterrupt:
        print("\n\n用户中断测试")
    except Exception as e:
        print(f"\n❌ 测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release the capture device and close preview windows.
        get_image.release()
        cv2.destroyAllWindows()
        print("🔚 测试结束")


if __name__ == "__main__":
    main()