""" 从main.py提取的YOLO识别测试文件 使用与main.py相同的识别逻辑 """ import cv2 from utils.get_image import GetImage from ultralytics import YOLO from config import config_manager import os # 检查模型文件是否存在 model_path = r"best.pt" model0_path = r"best0.pt" if not os.path.exists(model_path): print(f"❌ 模型文件不存在: {model_path}") exit(1) if not os.path.exists(model0_path): print(f"❌ 模型文件不存在: {model0_path}") exit(1) # 加载YOLO模型(与main.py保持一致) try: model = YOLO(model_path).to('cuda') model0 = YOLO(model0_path).to('cuda') print(f"✅ 模型加载成功: {model_path}") print(f"✅ 模型加载成功: {model0_path}") except Exception as e: print(f"❌ 模型加载失败: {e}") exit(1) def yolo_shibie(im_PIL, detections, model): """ YOLO识别函数(与main.py中的实现完全一致) :param im_PIL: PIL图像对象 :param detections: 检测结果字典 :param model: YOLO模型 :return: 更新后的detections字典 """ results = model(im_PIL) # 目标检测 for result in results: for i in range(len(result.boxes.xyxy)): left, top, right, bottom = result.boxes.xyxy[i] scalar_tensor = result.boxes.cls[i] value = scalar_tensor.item() label = result.names[int(value)] if label == 'center' or label == 'next' or label == 'boss' or label == 'zhaozi': player_x = int(left + (right - left) / 2) player_y = int(top + (bottom - top) / 2) + 30 RW = [player_x, player_y] detections[label] = RW elif label == 'daojv' or label == 'gw': player_x = int(left + (right - left) / 2) player_y = int(top + (bottom - top) / 2) + 30 RW = [player_x, player_y] detections[label].append(RW) elif label == 'npc1' or label == 'npc2' or label == 'npc3' or label == 'npc4': player_x = int(left + (right - left) / 2) player_y = int(bottom) + 30 RW = [player_x, player_y] detections[label] = RW return detections def main(): """主函数""" print("="*60) print("YOLO识别测试(main.py逻辑)") print("="*60) # 从配置加载采集卡设置 active_group = config_manager.get_active_group() if active_group is None: print("⚠️ 没有活动的配置组,使用默认设置") print("提示: 可以运行 python gui_config.py 设置配置") cam_index = 0 width = 1920 height = 1080 use_model = model # 默认使用model else: print(f"📋 使用配置组: {active_group['name']}") cam_index = active_group['camera_index'] width = active_group['camera_width'] height = active_group['camera_height'] use_model = model0 # 城镇中使用model0 print(f" 使用模型: model0 (best0.pt) - 用于城镇识别") print(f" 采集卡索引: {cam_index}") print(f" 分辨率: {width}x{height}") print() # 初始化采集卡 print("🔧 正在初始化采集卡...") get_image = GetImage( cam_index=cam_index, width=width, height=height ) if get_image.cap is None: print("❌ 采集卡初始化失败") print("请检查:") print("1. 采集卡是否正确连接") print("2. 采集卡索引是否正确") print("3. 采集卡驱动是否安装") return print("✅ 采集卡初始化成功") print("\n快捷键:") print(" 'q' 或 ESC - 退出") print(" 'm' - 切换模型 (model/model0)") print(" 'd' - 显示/隐藏检测信息") print() try: frame_count = 0 show_detections = True # 是否显示检测信息 current_model = use_model # 当前使用的模型 current_model_name = "model0" if use_model == model0 else "model" while True: # 获取帧 frame_data = get_image.get_frame() if frame_data is None: print("⚠️ 无法获取帧,跳过...") continue # frame_data 是 [im_opencv_rgb, im_PIL] 格式 im_opencv_rgb, im_PIL = frame_data if im_PIL is None: print("⚠️ PIL图像为空,跳过...") continue # 初始化检测结果字典 detections = { 'center': None, 'next': None, 'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None, 'boss': None, 'zhaozi': None, 'daojv': [], 'gw': [] } # 执行YOLO检测(使用main.py的逻辑) detections = yolo_shibie(im_PIL, detections, current_model) # 获取绘制好框的图像用于显示 try: results = current_model(im_PIL) result = results[0] frame_with_boxes_rgb = result.plot() frame_with_boxes_bgr = cv2.cvtColor(frame_with_boxes_rgb, cv2.COLOR_RGB2BGR) except Exception as e: print(f"⚠️ 绘制检测框失败: {e}") frame_with_boxes_bgr = cv2.cvtColor(im_opencv_rgb, cv2.COLOR_RGB2BGR) # 在图像上显示检测信息 if show_detections: # 显示模型名称 cv2.putText(frame_with_boxes_bgr, f"Model: {current_model_name}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2) # 显示检测到的目标 y_offset = 60 detected_items = [] for key, value in detections.items(): if value is not None and value != []: if key in ['daojv', 'gw']: detected_items.append(f"{key}: {len(value)}个") else: detected_items.append(f"{key}: {value}") if detected_items: text = f"Detected: {', '.join(detected_items[:5])}" # 最多显示5个 if len(detected_items) > 5: text += f" ... (+{len(detected_items)-5})" cv2.putText(frame_with_boxes_bgr, text, (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2) # 显示图像 cv2.imshow("YOLO Detection (main.py logic)", frame_with_boxes_bgr) # 检查按键 key = cv2.waitKey(1) & 0xFF if key in [27, ord('q'), ord('Q')]: print("\n用户退出") break elif key == ord('m') or key == ord('M'): # 切换模型 if current_model == model: current_model = model0 current_model_name = "model0" else: current_model = model current_model_name = "model" print(f"切换模型: {current_model_name}") elif key == ord('d') or key == ord('D'): show_detections = not show_detections print(f"显示检测信息: {'开启' if show_detections else '关闭'}") frame_count += 1 if frame_count % 30 == 0: # 每30帧打印一次 print(f"📊 已处理 {frame_count} 帧 (模型: {current_model_name})") # 打印有检测到的目标 detected_items = {k: v for k, v in detections.items() if v is not None and v != []} if detected_items: print(f" 检测到: {detected_items}") except KeyboardInterrupt: print("\n\n用户中断测试") except Exception as e: print(f"\n❌ 测试过程中发生错误: {e}") import traceback traceback.print_exc() finally: # 清理资源 get_image.release() cv2.destroyAllWindows() print("🔚 测试结束") if __name__ == "__main__": main()