Patch: pick the best npc4 candidate in yolo_test.py and add the yolotest2.py
capture-card recognition test.

Note for reviewers: this patch was recovered from a copy whose line breaks and
indentation had been collapsed. The indentation of the yolo_test.py context
lines (and the position of blank context lines) is reconstructed and must be
checked against the target file; if it does not match, apply with
`git apply --ignore-whitespace`.

diff --git a/yolo_test.py b/yolo_test.py
index 44af6bb..6bb7a72 100644
--- a/yolo_test.py
+++ b/yolo_test.py
@@ -161,6 +161,11 @@ def yolo_shibie(im_PIL, detections, model, enhance_enabled=False, enhance_params
 
         # ✅ 提取检测信息
         if result.boxes is not None and len(result.boxes.xyxy) > 0:
+            # Collect every npc4 detection so the best one can be chosen later.
+            npc4_candidates = []
+            # Best confidence accepted so far in this frame for npc1-npc3.
+            npc_best_conf = {}
+
             for i in range(len(result.boxes.xyxy)):
                 try:
                     left = float(result.boxes.xyxy[i][0])
@@ -169,11 +174,37 @@ def yolo_shibie(im_PIL, detections, model, enhance_enabled=False, enhance_params
                     bottom = float(result.boxes.xyxy[i][3])
                     cls_id = int(result.boxes.cls[i])
                     label = result.names[cls_id]
+
+                    # Confidence of this box; falls back to 1.0 when unavailable.
+                    confidence = float(result.boxes.conf[i]) if hasattr(result.boxes, 'conf') and len(result.boxes.conf) > i else 1.0
 
-                    if label in ['center', 'next', 'npc1', 'npc2', 'npc3', 'npc4', 'boss', 'zhaozi']:
+                    # npc1-npc4 use the bottom of the box (same as main.py).
+                    if label in ['npc1', 'npc2', 'npc3', 'npc4']:
+                        player_x = int(left + (right - left) / 2)
+                        player_y = int(bottom) + 30  # bottom edge, same as main.py
+                        position = [player_x, player_y]
+
+                        # npc4 may be detected several times: keep all candidates.
+                        if label == 'npc4':
+                            npc4_candidates.append({
+                                'position': position,
+                                'confidence': confidence,
+                                'box': [left, top, right, bottom],
+                                'area': (right - left) * (bottom - top)  # box area
+                            })
+                        else:
+                            # npc1-npc3: replace an existing value only with a more confident one.
+                            if detections[label] is None or confidence > npc_best_conf.get(label, 0.5):
+                                npc_best_conf[label] = confidence
+                                detections[label] = position
+
+                    # Other single targets use the box centre.
+                    elif label in ['center', 'next', 'boss', 'zhaozi']:
                         player_x = int(left + (right - left) / 2) + 3
                         player_y = int(top + (bottom - top) / 2) + 40
                         detections[label] = [player_x, player_y]
+
+                    # Items and monsters may appear several times.
                     elif label in ['daojv', 'gw']:
                         player_x = int(left + (right - left) / 2) + 3
                         player_y = int(top + (bottom - top) / 2) + 40
@@ -181,9 +212,34 @@ def yolo_shibie(im_PIL, detections, model, enhance_enabled=False, enhance_params
                         if label not in detections:
                             detections[label] = []
                         detections[label].append([player_x, player_y])
+
                 except Exception as e:
                     print(f"⚠️ 处理检测框时出错: {e}")
                     continue
+
+            # Resolve npc4: when several were detected, choose the most suitable.
+            if npc4_candidates:
+                # Highest confidence first.
+                npc4_candidates.sort(key=lambda x: x['confidence'], reverse=True)
+
+                # Best candidate = most confident one with a plausible box area.
+                best_npc4 = None
+                for candidate in npc4_candidates:
+                    # Confidence threshold: at least 0.3 (tune as needed).
+                    if candidate['confidence'] >= 0.3:
+                        # Reject tiny boxes, which are likely false positives.
+                        area = candidate['area']
+                        if area > 100:  # minimum area threshold
+                            best_npc4 = candidate
+                            break
+
+                if best_npc4:
+                    detections['npc4'] = best_npc4['position']
+                    # Optional debug output:
+                    # print(f"✅ 检测到npc4: 位置={best_npc4['position']}, 置信度={best_npc4['confidence']:.2f}")
+                elif len(npc4_candidates) == 1:
+                    # A single candidate is used even when it fails the thresholds.
+                    detections['npc4'] = npc4_candidates[0]['position']
 
     except Exception as e:
         print(f"⚠️ YOLO检测出错: {e}")
diff --git a/yolotest2.py b/yolotest2.py
new file mode 100644
index 0000000..22f40cd
--- /dev/null
+++ b/yolotest2.py
@@ -0,0 +1,227 @@
+"""
+YOLO recognition test extracted from main.py.
+
+Uses the same recognition logic as main.py.
+"""
+import os
+import sys
+import traceback
+
+import cv2
+from ultralytics import YOLO
+
+from config import config_manager
+from utils.get_image import GetImage
+
+# Make sure both model files exist before loading them.
+model_path = r"best.pt"
+model0_path = r"best0.pt"
+
+if not os.path.exists(model_path):
+    print(f"❌ 模型文件不存在: {model_path}")
+    sys.exit(1)
+if not os.path.exists(model0_path):
+    print(f"❌ 模型文件不存在: {model0_path}")
+    sys.exit(1)
+
+# Load the YOLO models (same as main.py).
+try:
+    model = YOLO(model_path).to('cuda')
+    model0 = YOLO(model0_path).to('cuda')
+    print(f"✅ 模型加载成功: {model_path}")
+    print(f"✅ 模型加载成功: {model0_path}")
+except Exception as e:
+    print(f"❌ 模型加载失败: {e}")
+    sys.exit(1)
+
+
+def yolo_shibie(im_PIL, detections, model, results=None):
+    """
+    Run YOLO on a frame and record target positions in ``detections``.
+
+    Same coordinate rules as main.py: 'center'/'next'/'boss'/'zhaozi' and
+    'daojv'/'gw' use the box centre shifted down by 30 px; 'npc1'-'npc4' use
+    the horizontal centre and the box bottom shifted down by 30 px.
+
+    :param im_PIL: PIL image to run detection on
+    :param detections: result dict, updated in place
+    :param model: YOLO model
+    :param results: optional output of ``model(im_PIL)``; when given, inference
+        is not run again so the caller can reuse it (e.g. for plotting)
+    :return: the updated detections dict
+    """
+    if results is None:
+        results = model(im_PIL)
+    for result in results:
+        boxes = result.boxes
+        if boxes is None:
+            # Nothing to read from this result.
+            continue
+        for box, cls in zip(boxes.xyxy, boxes.cls):
+            left, top, right, bottom = box
+            label = result.names[int(cls.item())]
+            player_x = int(left + (right - left) / 2)
+            if label in ('center', 'next', 'boss', 'zhaozi'):
+                detections[label] = [player_x, int(top + (bottom - top) / 2) + 30]
+            elif label in ('daojv', 'gw'):
+                # setdefault: tolerate a dict without the list pre-created.
+                detections.setdefault(label, []).append(
+                    [player_x, int(top + (bottom - top) / 2) + 30])
+            elif label in ('npc1', 'npc2', 'npc3', 'npc4'):
+                detections[label] = [player_x, int(bottom) + 30]
+    return detections
+
+
+def main():
+    """Run the interactive capture-card YOLO recognition test loop."""
+    print("="*60)
+    print("YOLO识别测试(main.py逻辑)")
+    print("="*60)
+
+    # Load the capture-card settings from the active config group.
+    active_group = config_manager.get_active_group()
+
+    if active_group is None:
+        print("⚠️ 没有活动的配置组,使用默认设置")
+        print("提示: 可以运行 python gui_config.py 设置配置")
+        cam_index = 0
+        width = 1920
+        height = 1080
+        use_model = model  # default to `model`
+    else:
+        print(f"📋 使用配置组: {active_group['name']}")
+        cam_index = active_group['camera_index']
+        width = active_group['camera_width']
+        height = active_group['camera_height']
+        use_model = model0  # model0 is used for town scenes
+        print(" 使用模型: model0 (best0.pt) - 用于城镇识别")
+
+    print(f" 采集卡索引: {cam_index}")
+    print(f" 分辨率: {width}x{height}")
+    print()
+
+    # Initialise the capture card.
+    print("🔧 正在初始化采集卡...")
+    get_image = GetImage(
+        cam_index=cam_index,
+        width=width,
+        height=height
+    )
+
+    if get_image.cap is None:
+        print("❌ 采集卡初始化失败")
+        print("请检查:")
+        print("1. 采集卡是否正确连接")
+        print("2. 采集卡索引是否正确")
+        print("3. 采集卡驱动是否安装")
+        return
+
+    print("✅ 采集卡初始化成功")
+    print("\n快捷键:")
+    print(" 'q' 或 ESC - 退出")
+    print(" 'm' - 切换模型 (model/model0)")
+    print(" 'd' - 显示/隐藏检测信息")
+    print()
+
+    try:
+        frame_count = 0
+        show_detections = True  # whether to draw the text overlay
+        current_model = use_model
+        # Identity check: the two names refer to distinct model objects.
+        current_model_name = "model0" if use_model is model0 else "model"
+
+        while True:
+            frame_data = get_image.get_frame()
+
+            if frame_data is None:
+                print("⚠️ 无法获取帧,跳过...")
+                continue
+
+            # frame_data is [im_opencv_rgb, im_PIL]
+            im_opencv_rgb, im_PIL = frame_data
+
+            if im_PIL is None:
+                print("⚠️ PIL图像为空,跳过...")
+                continue
+
+            # Fresh result dict for this frame.
+            detections = {
+                'center': None, 'next': None,
+                'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None,
+                'boss': None, 'zhaozi': None,
+                'daojv': [], 'gw': []
+            }
+
+            # Run inference once and reuse it for parsing and for plotting.
+            results = current_model(im_PIL)
+            detections = yolo_shibie(im_PIL, detections, current_model,
+                                     results=results)
+
+            try:
+                # Results.plot() already returns a BGR array, which is what
+                # cv2.imshow expects, so no colour conversion is applied.
+                frame_with_boxes_bgr = results[0].plot()
+            except Exception as e:
+                print(f"⚠️ 绘制检测框失败: {e}")
+                frame_with_boxes_bgr = cv2.cvtColor(im_opencv_rgb, cv2.COLOR_RGB2BGR)
+
+            if show_detections:
+                cv2.putText(frame_with_boxes_bgr, f"Model: {current_model_name}",
+                            (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
+
+                y_offset = 60
+                detected_items = []
+                for name, value in detections.items():
+                    if value is not None and value != []:
+                        if name in ('daojv', 'gw'):
+                            # ASCII only: Hershey fonts cannot draw CJK text.
+                            detected_items.append(f"{name}: {len(value)}")
+                        else:
+                            detected_items.append(f"{name}: {value}")
+
+                if detected_items:
+                    text = f"Detected: {', '.join(detected_items[:5])}"  # first 5 only
+                    if len(detected_items) > 5:
+                        text += f" ... (+{len(detected_items)-5})"
+                    cv2.putText(frame_with_boxes_bgr, text,
+                                (10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)
+
+            cv2.imshow("YOLO Detection (main.py logic)", frame_with_boxes_bgr)
+
+            # Keyboard handling.
+            pressed = cv2.waitKey(1) & 0xFF
+            if pressed in (27, ord('q'), ord('Q')):
+                print("\n用户退出")
+                break
+            elif pressed in (ord('m'), ord('M')):
+                # Toggle between the two models.
+                if current_model is model:
+                    current_model, current_model_name = model0, "model0"
+                else:
+                    current_model, current_model_name = model, "model"
+                print(f"切换模型: {current_model_name}")
+            elif pressed in (ord('d'), ord('D')):
+                show_detections = not show_detections
+                print(f"显示检测信息: {'开启' if show_detections else '关闭'}")
+
+            frame_count += 1
+            if frame_count % 30 == 0:  # log every 30 frames
+                print(f"📊 已处理 {frame_count} 帧 (模型: {current_model_name})")
+                found = {k: v for k, v in detections.items() if v is not None and v != []}
+                if found:
+                    print(f" 检测到: {found}")
+
+    except KeyboardInterrupt:
+        print("\n\n用户中断测试")
+    except Exception as e:
+        print(f"\n❌ 测试过程中发生错误: {e}")
+        traceback.print_exc()
+    finally:
+        # Always release the capture device and close windows.
+        get_image.release()
+        cv2.destroyAllWindows()
+        print("🔚 测试结束")
+
+
+if __name__ == "__main__":
+    main()