Files
huojv/yolo_test.py
2025-11-04 14:47:28 +08:00

398 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import sys
import time

import cv2
import numpy as np
from ultralytics import YOLO

from config import config_manager
from utils.get_image import GetImage
from utils.logger import logger
# Verify the model weights exist before trying to load them.
model_path = r"best0.pt"
if not os.path.exists(model_path):
    print(f"❌ 模型文件不存在: {model_path}")
    # sys.exit instead of the site-module builtin `exit`, which is missing
    # under `python -S` and in frozen/packaged executables.
    sys.exit(1)

# Load the YOLO model; a failure here is fatal.
try:
    model = YOLO(model_path)
except Exception as e:
    print(f"❌ 模型加载失败: {e}")
    sys.exit(1)

# Prefer the GPU, but keep running on CPU when CUDA is not available instead
# of aborting the whole script.
try:
    model = model.to('cuda')
except Exception as e:
    print(f"⚠️ 无法使用CUDA,改用CPU运行: {e}")

print(f"✅ 模型加载成功: {model_path}")
def enhance_sharpness(image, strength=1.5):
    """
    Sharpen an image with a normalised 3x3 sharpening kernel.

    The kernel has -1 on the eight neighbours and ``9 * strength`` in the
    centre. It is divided by its own sum (``9 * strength - 8``) so the weights
    add up to 1 and overall brightness is preserved.

    NOTE: because of that normalisation, the neighbour weight is
    ``-1 / (9 * strength - 8)``. *Larger* ``strength`` values therefore give a
    *milder* effect. 1.0 yields the classic full-strength sharpen kernel, and
    the result approaches the unmodified image as ``strength`` grows.

    :param image: input image (BGR)
    :param strength: kernel parameter, intended range 1.0-3.0 (default 1.5)
    :return: sharpened image with the same depth as the input (ddepth=-1)
    :raises ValueError: if ``strength <= 8/9``. The normalising sum would be
        zero (division by zero) or negative (which inverts the kernel).
    """
    denominator = 9 * strength - 8
    if denominator <= 0:
        raise ValueError(
            f"strength must be greater than 8/9 (got {strength}); "
            "the sharpening kernel cannot be normalised otherwise"
        )
    kernel = np.array([[-1, -1, -1],
                       [-1, 9 * strength, -1],
                       [-1, -1, -1]]) / denominator
    return cv2.filter2D(image, -1, kernel)
def enhance_contrast(image, alpha=1.2, beta=10):
    """
    Linearly boost contrast and brightness.

    Each pixel becomes ``|alpha * pixel + beta|``, saturated by
    ``cv2.convertScaleAbs``.

    :param image: input image
    :param alpha: contrast gain (suggested range 1.0-3.0, default 1.2)
    :param beta: brightness offset (suggested range -100..100, default 10)
    :return: adjusted image
    """
    gain, offset = alpha, beta
    adjusted = cv2.convertScaleAbs(image, alpha=gain, beta=offset)
    return adjusted
def denoise_image(image, method='bilateral'):
    """
    Reduce image noise with one of several OpenCV filters.

    :param image: input image
    :param method: 'bilateral' (smooths while keeping edges), 'gaussian'
        (5x5 Gaussian blur) or 'fastNlMeans' (non-local means; best quality
        but slowest)
    :return: filtered image, or the very same input object for any other
        method name
    """
    filters = (
        # d=9, sigmaColor=75, sigmaSpace=75
        ('bilateral', lambda img: cv2.bilateralFilter(img, 9, 75, 75)),
        # 5x5 kernel; sigma 0 lets OpenCV derive it from the kernel size
        ('gaussian', lambda img: cv2.GaussianBlur(img, (5, 5), 0)),
        # h=10, hColor=10, templateWindowSize=7, searchWindowSize=21
        ('fastNlMeans',
         lambda img: cv2.fastNlMeansDenoisingColored(img, None, 10, 10, 7, 21)),
    )
    for name, apply_filter in filters:
        if method == name:
            return apply_filter(image)
    return image
def apply_enhancements(image, sharpness=True, contrast=True, denoise=True,
                       sharp_strength=1.5, contrast_alpha=1.2, contrast_beta=10,
                       denoise_method='bilateral'):
    """
    Run the enabled enhancement steps on a copy of the image.

    Steps run in a fixed order: denoise, then contrast/brightness, then
    sharpening. Each is applied only when its flag is truthy.

    :param image: input image (BGR); never modified in place
    :param sharpness: apply enhance_sharpness
    :param contrast: apply enhance_contrast
    :param denoise: apply denoise_image
    :param sharp_strength: forwarded to enhance_sharpness as ``strength``
    :param contrast_alpha: forwarded to enhance_contrast as ``alpha``
    :param contrast_beta: forwarded to enhance_contrast as ``beta``
    :param denoise_method: forwarded to denoise_image as ``method``
    :return: the enhanced image (a new object even when no step is enabled)
    """
    pipeline = (
        (denoise, lambda img: denoise_image(img, denoise_method)),
        (contrast, lambda img: enhance_contrast(img, contrast_alpha, contrast_beta)),
        (sharpness, lambda img: enhance_sharpness(img, sharp_strength)),
    )
    result = image.copy()
    for enabled, step in pipeline:
        if enabled:
            result = step(result)
    return result
def set_camera_properties(cap, brightness=None, contrast=None, saturation=None,
                          sharpness=None, gain=None, exposure=None):
    """
    Apply capture-card hardware settings, skipping any argument left as None.

    Each requested value is written with ``cap.set``. The value read back with
    ``cap.get`` is then logged so you can see what the driver actually uses.
    A warning is logged when ``cap.set`` returns False (the backend does not
    support the property) or when setting/reading raises.

    :param cap: cv2.VideoCapture object
    :param brightness: brightness (0-100)
    :param contrast: contrast (0-100)
    :param saturation: saturation (0-100)
    :param sharpness: sharpness (0-100)
    :param gain: gain (0-100)
    :param exposure: exposure (usually negative, e.g. -6)
    """
    # (readable name for logs, OpenCV property id, requested value).
    # The cv2.CAP_PROP_* constants are plain ints with no `.name`, so the
    # readable names have to be kept explicitly.
    requests = (
        ('brightness', cv2.CAP_PROP_BRIGHTNESS, brightness),
        ('contrast', cv2.CAP_PROP_CONTRAST, contrast),
        ('saturation', cv2.CAP_PROP_SATURATION, saturation),
        ('sharpness', cv2.CAP_PROP_SHARPNESS, sharpness),
        ('gain', cv2.CAP_PROP_GAIN, gain),
        ('exposure', cv2.CAP_PROP_EXPOSURE, exposure),
    )
    for name, prop_id, value in requests:
        if value is None:
            continue
        try:
            accepted = cap.set(prop_id, value)
            actual = cap.get(prop_id)
        except Exception as e:
            logger.warning(f" ⚠️ 设置参数 {name} 失败: {e}")
            continue
        if accepted:
            logger.info(f" 设置 {name}: {value} -> 实际: {actual:.2f}")
        else:
            # cap.set returns False when the backend does not support the property.
            logger.warning(f" ⚠️ 采集卡后端不支持参数 {name}: 请求 {value} -> 实际: {actual:.2f}")
# Labels reported as the bottom-centre of their box (npc4 gets extra filtering).
_NPC_LABELS = ('npc1', 'npc2', 'npc3', 'npc4')
# Labels reported as a single (offset) box centre; the last box seen wins.
_POINT_LABELS = ('center', 'next', 'boss', 'zhaozi')
# Labels that may appear several times; every box is appended to a list.
_MULTI_LABELS = ('daojv', 'gw')


def _to_list(values):
    """Convert a tensor / ndarray / sequence into a plain Python list."""
    # One bulk conversion instead of a float()/int() call per element.
    return values.tolist() if hasattr(values, 'tolist') else list(values)


def _show_detection_frame(result, enhance_enabled, enhance_params):
    """Show the annotated frame (optionally enhanced) in the preview window."""
    # Results.plot() already returns a BGR array (ultralytics turns PIL input
    # into BGR), so it goes straight to cv2.imshow. Converting RGB->BGR here
    # would swap the red and blue channels.
    frame = result.plot()
    # Enhancement affects only this preview; detection already ran on the
    # raw frame.
    if enhance_enabled and enhance_params:
        try:
            frame = apply_enhancements(frame, **enhance_params)
        except Exception as e:
            print(f"⚠️ 图像增强失败: {e}")
    cv2.imshow("YOLO Real-time Detection", frame)


def _pick_npc4(candidates):
    """Choose which npc4 box to report when one frame contains several.

    Returns the position of the highest-confidence candidate whose confidence
    is >= 0.3 and whose box area is > 100 (rejects tiny false positives).
    If none qualifies but there is exactly one candidate, that one is used
    anyway. Otherwise returns None.
    """
    ranked = sorted(candidates, key=lambda c: c['confidence'], reverse=True)
    best = next((c for c in ranked
                 if c['confidence'] >= 0.3 and c['area'] > 100), None)
    if best is not None:
        return best['position']
    if len(candidates) == 1:
        return candidates[0]['position']
    return None


def _collect_detections(result, detections):
    """Write the positions of all boxes in ``result`` into ``detections``."""
    boxes = result.boxes
    if boxes is None:
        return
    coords = _to_list(boxes.xyxy)
    if not coords:
        return
    classes = _to_list(boxes.cls)
    conf_values = getattr(boxes, 'conf', None)
    has_conf = conf_values is not None
    confs = _to_list(conf_values) if has_conf else []

    npc4_candidates = []
    # Confidence of the box currently stored for each of npc1-npc3 in this frame.
    best_conf = {}
    for i, box in enumerate(coords):
        try:
            left, top, right, bottom = (float(box[k]) for k in range(4))
            label = result.names[int(classes[i])]
            # Boxes without a confidence score count as fully confident.
            confidence = float(confs[i]) if i < len(confs) else 1.0

            if label in _NPC_LABELS:
                # Bottom-centre of the box plus a 30 px offset in y
                # (kept consistent with main.py).
                position = [int(left + (right - left) / 2), int(bottom) + 30]
                if label == 'npc4':
                    npc4_candidates.append({
                        'position': position,
                        'confidence': confidence,
                        'box': [left, top, right, bottom],
                        'area': (right - left) * (bottom - top),
                    })
                    continue
                # npc1-npc3: keep the most confident box of this frame. The
                # first box is taken if the slot is empty, or if it is
                # confident (> 0.5). Later boxes replace it only with a
                # strictly higher confidence.
                previous = best_conf.get(label)
                if previous is None:
                    accept = (detections.get(label) is None
                              or (has_conf and confidence > 0.5))
                else:
                    accept = has_conf and confidence > previous
                if accept:
                    detections[label] = position
                    best_conf[label] = confidence
            elif label in _POINT_LABELS:
                detections[label] = [int(left + (right - left) / 2) + 3,
                                     int(top + (bottom - top) / 2) + 40]
            elif label in _MULTI_LABELS:
                detections.setdefault(label, []).append(
                    [int(left + (right - left) / 2) + 3,
                     int(top + (bottom - top) / 2) + 40])
        except Exception as e:
            print(f"⚠️ 处理检测框时出错: {e}")

    npc4_position = _pick_npc4(npc4_candidates) if npc4_candidates else None
    if npc4_position is not None:
        detections['npc4'] = npc4_position


def yolo_shibie(im_PIL, detections, model, enhance_enabled=False, enhance_params=None):
    """
    Run YOLO on one frame, show the annotated result and record target positions.

    :param im_PIL: PIL image to run detection on; None skips the frame
    :param detections: dict of label -> position (or list of positions for
        'daojv'/'gw'); updated in place
    :param model: callable YOLO model; ``model(im_PIL)[0]`` is the result used
    :param enhance_enabled: whether to enhance the preview image
    :param enhance_params: keyword arguments for apply_enhancements
    :return: the same ``detections`` dict (never None). Errors are printed,
        not raised.
    """
    if im_PIL is None:
        return detections
    try:
        result = model(im_PIL)[0]
    except Exception as e:
        print(f"⚠️ YOLO检测出错: {e}")
        return detections
    # A display problem must not throw away this frame's detections.
    try:
        _show_detection_frame(result, enhance_enabled, enhance_params)
    except Exception as e:
        print(f"⚠️ 显示检测结果失败: {e}")
    try:
        _collect_detections(result, detections)
    except Exception as e:
        print(f"⚠️ YOLO检测出错: {e}")
    return detections
# Range for the sharpening parameter adjusted with keys '1'/'2'. The lower
# bound is 1.0, the documented minimum of enhance_sharpness: its kernel
# normaliser 9*strength-8 reaches zero at 8/9 and turns negative below it.
_SHARP_STRENGTH_RANGE = (1.0, 3.0)
# Range for the contrast gain adjusted with keys '3'/'4'.
_CONTRAST_ALPHA_RANGE = (0.5, 3.0)


def _read_pil_frame(get_image):
    """Grab one frame from the capture card; return its PIL image or None."""
    frame_data = get_image.get_frame()
    if frame_data is None:
        print("⚠️ 无法获取帧,跳过...")
        return None
    # frame_data is [im_opencv_rgb, im_PIL]; only the PIL image is used here.
    _, im_PIL = frame_data
    if im_PIL is None:
        print("⚠️ PIL图像为空跳过...")
    return im_PIL


def _handle_key(key, enhance_enabled, enhance_params):
    """Apply one keyboard command.

    :param key: key code from ``cv2.waitKey(...) & 0xFF``
    :param enhance_enabled: current enhancement toggle
    :param enhance_params: enhancement settings; updated in place
    :return: (quit_requested, new enhance_enabled)
    """
    sharp_lo, sharp_hi = _SHARP_STRENGTH_RANGE
    alpha_lo, alpha_hi = _CONTRAST_ALPHA_RANGE
    if key in (27, ord('q'), ord('Q')):
        print("\n用户退出")
        return True, enhance_enabled
    if key in (ord('e'), ord('E')):
        enhance_enabled = not enhance_enabled
        status = "开启" if enhance_enabled else "关闭"
        print(f"图像增强: {status} (锐化={enhance_params['sharp_strength']:.1f}, "
              f"对比度={enhance_params['contrast_alpha']:.1f})")
    elif key == ord('1'):
        enhance_params['sharp_strength'] = min(sharp_hi, enhance_params['sharp_strength'] + 0.1)
        print(f"锐化强度: {enhance_params['sharp_strength']:.1f}")
    elif key == ord('2'):
        enhance_params['sharp_strength'] = max(sharp_lo, enhance_params['sharp_strength'] - 0.1)
        print(f"锐化强度: {enhance_params['sharp_strength']:.1f}")
    elif key == ord('3'):
        enhance_params['contrast_alpha'] = min(alpha_hi, enhance_params['contrast_alpha'] + 0.1)
        print(f"对比度: {enhance_params['contrast_alpha']:.1f}")
    elif key == ord('4'):
        enhance_params['contrast_alpha'] = max(alpha_lo, enhance_params['contrast_alpha'] - 0.1)
        print(f"对比度: {enhance_params['contrast_alpha']:.1f}")
    return False, enhance_enabled


def main():
    """Interactive real-time YOLO detection test on the capture card."""
    print("="*60)
    print("YOLO实时检测测试")
    print("="*60)

    # Load the capture-card settings from the active config group, if any.
    active_group = config_manager.get_active_group()
    if active_group is None:
        print("⚠️ 没有活动的配置组,使用默认设置")
        print("提示: 可以运行 python gui_config.py 设置配置")
        cam_index = 0
        width = 1920
        height = 1080
    else:
        print(f"📋 使用配置组: {active_group['name']}")
        cam_index = active_group['camera_index']
        width = active_group['camera_width']
        height = active_group['camera_height']
    print(f" 采集卡索引: {cam_index}")
    print(f" 分辨率: {width}x{height}")
    print()

    # Open the capture card.
    print("🔧 正在初始化采集卡...")
    get_image = GetImage(
        cam_index=cam_index,
        width=width,
        height=height
    )
    if get_image.cap is None:
        print("❌ 采集卡初始化失败")
        print("请检查:")
        print("1. 采集卡是否正确连接")
        print("2. 采集卡索引是否正确")
        print("3. 采集卡驱动是否安装")
        return

    # Optional hardware settings to improve image clarity.
    print("\n🔧 设置采集卡参数以提高清晰度...")
    print("提示: 可以根据实际情况调整这些参数")
    set_camera_properties(
        get_image.cap,
        brightness=50,   # 0-100
        contrast=50,     # 0-100
        saturation=55,   # 0-100
        sharpness=60,    # 0-100, raised for clarity
        gain=None,       # adjust for the actual device
        exposure=None    # adjust for the actual device, usually negative
    )
    print("✅ 采集卡初始化成功")
    print("\n快捷键:")
    print(" 'q' 或 ESC - 退出")
    print(" 'e' - 切换图像增强")
    print(" '1'/'2' - 调整锐化强度 (+/-0.1)")
    print(" '3'/'4' - 调整对比度 (+/-0.1)")
    print()

    try:
        frame_count = 0
        enhance_enabled = False  # enhancement is off by default
        enhance_params = {
            'sharpness': True,
            'contrast': True,
            'denoise': True,
            'sharp_strength': 1.5,
            'contrast_alpha': 1.2,
            'contrast_beta': 10,
            'denoise_method': 'bilateral'
        }
        while True:
            im_PIL = _read_pil_frame(get_image)
            if im_PIL is None:
                # Back off briefly so a failing capture card does not spin
                # the CPU, and keep polling keys so 'q'/ESC still work once
                # the preview window exists.
                time.sleep(0.05)
                quit_requested, enhance_enabled = _handle_key(
                    cv2.waitKey(1) & 0xFF, enhance_enabled, enhance_params)
                if quit_requested:
                    break
                continue

            # Fresh result dict for every frame.
            detections = {
                'center': None, 'next': None,
                'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None,
                'boss': None, 'zhaozi': None,
                'daojv': [], 'gw': []
            }
            detections = yolo_shibie(im_PIL, detections, model, enhance_enabled, enhance_params)

            quit_requested, enhance_enabled = _handle_key(
                cv2.waitKey(1) & 0xFF, enhance_enabled, enhance_params)
            if quit_requested:
                break

            frame_count += 1
            if frame_count % 30 == 0:  # report every 30 frames
                print(f"📊 已处理 {frame_count}")
                detected_items = {k: v for k, v in detections.items() if v is not None and v != []}
                if detected_items:
                    print(f" 检测到: {detected_items}")
    except KeyboardInterrupt:
        print("\n\n用户中断测试")
    except Exception as e:
        print(f"\n❌ 测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release the capture card and close the preview window.
        get_image.release()
        cv2.destroyAllWindows()
        print("🔚 测试结束")
# Start the interactive detection loop only when this file is run as a script.
# Note: the module-level model loading above still runs on import.
if __name__ == "__main__":
    main()