398 lines
15 KiB
Python
398 lines
15 KiB
Python
import cv2
|
||
from utils.get_image import GetImage
|
||
from ultralytics import YOLO
|
||
from config import config_manager
|
||
from utils.logger import logger
|
||
import os
|
||
import numpy as np
|
||
|
||
# Fail fast when the weights file is missing, before any model work starts.
model_path = r"best0.pt"
if not os.path.exists(model_path):
    print(f"❌ 模型文件不存在: {model_path}")
    # `exit` is a helper injected by the `site` module for interactive use and
    # may be missing (e.g. `python -S`, frozen builds); SystemExit always works.
    raise SystemExit(1)

# Load the YOLO model and move it to the CUDA device; any failure while
# loading or moving the model aborts the script with exit status 1.
try:
    model = YOLO(model_path).to('cuda')
    print(f"✅ 模型加载成功: {model_path}")
except Exception as e:
    print(f"❌ 模型加载失败: {e}")
    raise SystemExit(1)
|
||
|
||
def enhance_sharpness(image, strength=1.5):
    """
    Sharpen an image with a normalised 3x3 convolution kernel.

    The kernel holds -1 in its eight outer cells and ``9 * strength`` in the
    centre, and is divided by ``9 * strength - 8`` so its weights sum to 1.

    :param image: input image (BGR format)
    :param strength: sharpening strength (1.0-3.0, default 1.5); it must be
        greater than 8/9, otherwise the normaliser is zero or negative
    :return: the sharpened image
    :raises ValueError: if ``strength`` makes the normaliser non-positive
    """
    centre = 9 * strength
    normaliser = centre - 8
    # A zero normaliser divides by zero and a negative one flips the sign of
    # every weight, so reject those values instead of filtering with them.
    if normaliser <= 0:
        raise ValueError(
            f"strength must be greater than 8/9 (got {strength!r})"
        )

    kernel = np.array([[-1, -1, -1],
                       [-1, centre, -1],
                       [-1, -1, -1]]) / normaliser
    # ddepth=-1 keeps the output depth equal to the input depth.
    return cv2.filter2D(image, -1, kernel)
|
||
|
||
|
||
def enhance_contrast(image, alpha=1.2, beta=10):
    """
    Adjust contrast and brightness of an image.

    Delegates to ``cv2.convertScaleAbs`` with the given gain and offset.

    :param image: input image
    :param alpha: contrast control (1.0-3.0, default 1.2)
    :param beta: brightness control (-100 to 100, default 10)
    :return: the adjusted image
    """
    adjusted = cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
    return adjusted
|
||
|
||
|
||
def denoise_image(image, method='bilateral'):
    """
    Remove noise from an image using the selected filter.

    :param image: input image
    :param method: denoising method ('bilateral', 'gaussian', 'fastNlMeans')
    :return: the denoised image; the input is returned unchanged when
        ``method`` is not one of the supported names
    """
    if method == 'bilateral':
        # Bilateral filter: smooths noise while preserving edges.
        filtered = cv2.bilateralFilter(image, 9, 75, 75)
        return filtered

    if method == 'gaussian':
        # Gaussian blur with a 5x5 kernel.
        blurred = cv2.GaussianBlur(image, (5, 5), 0)
        return blurred

    if method == 'fastNlMeans':
        # Non-local means denoising (best quality but slower).
        cleaned = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
        return cleaned

    # Unknown method: pass the image through untouched.
    return image
|
||
|
||
|
||
def apply_enhancements(image, sharpness=True, contrast=True, denoise=True,
                       sharp_strength=1.5, contrast_alpha=1.2, contrast_beta=10,
                       denoise_method='bilateral'):
    """
    Run the enabled enhancement stages on a copy of the image.

    Stages run in a fixed order: denoise, then contrast, then sharpening.

    :param image: input image (BGR format)
    :param sharpness: whether to sharpen
    :param contrast: whether to enhance contrast
    :param denoise: whether to denoise
    :param sharp_strength: sharpening strength
    :param contrast_alpha: contrast coefficient
    :param contrast_beta: brightness adjustment
    :param denoise_method: denoising method
    :return: the enhanced image
    """
    # Each entry pairs an on/off flag with the stage it controls.
    stages = (
        (denoise, lambda img: denoise_image(img, denoise_method)),
        (contrast, lambda img: enhance_contrast(img, contrast_alpha, contrast_beta)),
        (sharpness, lambda img: enhance_sharpness(img, sharp_strength)),
    )

    # Work on a copy so the caller's image is never modified.
    result = image.copy()
    for enabled, stage in stages:
        if enabled:
            result = stage(result)
    return result
|
||
|
||
|
||
def set_camera_properties(cap, brightness=None, contrast=None, saturation=None,
                          sharpness=None, gain=None, exposure=None):
    """
    Apply hardware parameters to the capture device.

    Only parameters that are not ``None`` are written. After each write the
    value is read back from the device and logged next to the requested one.

    :param cap: VideoCapture object
    :param brightness: brightness (0-100)
    :param contrast: contrast (0-100)
    :param saturation: saturation (0-100)
    :param sharpness: sharpness (0-100)
    :param gain: gain (0-100)
    :param exposure: exposure (usually negative, e.g. -6)
    """
    requested = (
        (cv2.CAP_PROP_BRIGHTNESS, brightness),
        (cv2.CAP_PROP_CONTRAST, contrast),
        (cv2.CAP_PROP_SATURATION, saturation),
        (cv2.CAP_PROP_SHARPNESS, sharpness),
        (cv2.CAP_PROP_GAIN, gain),
        (cv2.CAP_PROP_EXPOSURE, exposure),
    )

    for prop, value in requested:
        if value is None:
            continue
        try:
            cap.set(prop, value)
            actual = cap.get(prop)
            shown = prop.name if hasattr(prop, 'name') else prop
            logger.info(f" 设置 {shown}: {value} -> 实际: {actual:.2f}")
        except Exception as exc:
            # Best effort: a failing property must not stop the remaining ones.
            logger.warning(f" ⚠️ 设置参数 {prop} 失败: {exc}")
|
||
|
||
|
||
# Labels positioned at the bottom-centre of their box.
_NPC_LABELS = ('npc1', 'npc2', 'npc3', 'npc4')
# Labels that hold a single position (the last box seen wins).
_SINGLE_TARGET_LABELS = ('center', 'next', 'boss', 'zhaozi')
# Labels that accumulate a list of positions.
_MULTI_TARGET_LABELS = ('daojv', 'gw')


def _box_confidence(boxes, index):
    """Return the confidence of box *index*, or 1.0 when none is available."""
    if hasattr(boxes, 'conf') and len(boxes.conf) > index:
        return float(boxes.conf[index])
    return 1.0


def _select_npc4(candidates):
    """Pick the npc4 candidate to report, or None when none is acceptable."""
    ranked = sorted(candidates, key=lambda c: c['confidence'], reverse=True)
    for candidate in ranked:
        # Require confidence >= 0.3 and a box area above 100 px^2 so that
        # tiny false positives are skipped.
        if candidate['confidence'] >= 0.3 and candidate['area'] > 100:
            return candidate
    # A lone candidate is used even if it missed the thresholds.
    if len(ranked) == 1:
        return ranked[0]
    return None


def _collect_detections(result, detections):
    """Fold the boxes of one YOLO result into ``detections`` in place."""
    boxes = result.boxes
    if boxes is None or len(boxes.xyxy) == 0:
        return

    has_conf = hasattr(boxes, 'conf')
    accepted_conf = {}  # confidence of the box currently stored per NPC label
    npc4_candidates = []

    for index, coords in enumerate(boxes.xyxy):
        try:
            left, top, right, bottom = (float(coords[k]) for k in range(4))
            label = result.names[int(boxes.cls[index])]
            confidence = _box_confidence(boxes, index)
            centre_x = int(left + (right - left) / 2)

            if label in _NPC_LABELS:
                # NPCs use the bottom edge of the box, shifted 30 px down.
                position = [centre_x, int(bottom) + 30]
                if label == 'npc4':
                    npc4_candidates.append({
                        'position': position,
                        'confidence': confidence,
                        'box': [left, top, right, bottom],
                        'area': (right - left) * (bottom - top),
                    })
                elif label in accepted_conf:
                    # A box was already stored during this call: replace it
                    # only when the new one is strictly more confident.
                    if confidence > accepted_conf[label]:
                        detections[label] = position
                        accepted_conf[label] = confidence
                elif detections.get(label) is None or (has_conf and confidence > 0.5):
                    # First box for this label: fill an empty slot, or replace
                    # a caller-supplied value only above 0.5 confidence.
                    detections[label] = position
                    accepted_conf[label] = confidence

            elif label in _SINGLE_TARGET_LABELS:
                # Box centre with a fixed (+3, +40) pixel offset.
                centre_y = int(top + (bottom - top) / 2)
                detections[label] = [centre_x + 3, centre_y + 40]

            elif label in _MULTI_TARGET_LABELS:
                centre_y = int(top + (bottom - top) / 2)
                detections.setdefault(label, []).append([centre_x + 3, centre_y + 40])

        except Exception as e:
            # One malformed box must not discard the others.
            print(f"⚠️ 处理检测框时出错: {e}")
            continue

    chosen = _select_npc4(npc4_candidates)
    if chosen is not None:
        detections['npc4'] = chosen['position']


def _show_detection_frame(result, enhance_enabled, enhance_params):
    """Render the annotated frame in the OpenCV preview window."""
    annotated = result.plot()
    # NOTE(review): Ultralytics documents Results.plot() as returning a BGR
    # array; if that holds for the installed version this swap inverts the
    # red/blue channels on screen. Confirm before removing it.
    display_frame = cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR)

    if enhance_enabled and enhance_params:
        try:
            display_frame = apply_enhancements(display_frame, **enhance_params)
        except Exception as e:
            # Fall back to the un-enhanced frame.
            print(f"⚠️ 图像增强失败: {e}")

    cv2.imshow("YOLO Real-time Detection", display_frame)


def yolo_shibie(im_PIL, detections, model, enhance_enabled=False, enhance_params=None):
    """
    Run YOLO on one frame, update ``detections`` and show the annotated frame.

    Detections are extracted before the frame is displayed, so a display
    failure does not lose the results for that frame.

    :param im_PIL: PIL image object; ``None`` skips the frame
    :param detections: detection result dict, updated in place
    :param model: YOLO model (called as ``model(im_PIL)``)
    :param enhance_enabled: whether to enhance the displayed frame
    :param enhance_params: keyword arguments for ``apply_enhancements``
    :return: the same ``detections`` dict (errors are printed, never raised)
    """
    if im_PIL is None:
        return detections

    try:
        result = model(im_PIL)[0]
    except Exception as e:
        print(f"⚠️ YOLO检测出错: {e}")
        return detections

    try:
        _collect_detections(result, detections)
    except Exception as e:
        print(f"⚠️ YOLO检测出错: {e}")

    try:
        _show_detection_frame(result, enhance_enabled, enhance_params)
    except Exception as e:
        print(f"⚠️ YOLO检测出错: {e}")

    return detections
|
||
|
||
|
||
def main():
    """
    Run the live YOLO detection test against the capture card.

    Reads the capture settings from the active configuration group (or falls
    back to device 0 at 1920x1080), applies the hardware parameters, then
    loops: grab a frame, run ``yolo_shibie`` on it and handle hotkeys until
    the user quits. The capture device and windows are released on exit.
    """
    # Bounds for the interactive adjustments. Sharpening stays inside the
    # 1.0-3.0 range documented by enhance_sharpness: its kernel normaliser
    # 9*s - 8 is near zero around s = 0.9 and negative below 8/9.
    sharp_min, sharp_max = 1.0, 3.0
    contrast_min, contrast_max = 0.5, 3.0

    print("="*60)
    print("YOLO实时检测测试")
    print("="*60)

    # Load the capture card settings from the configuration.
    active_group = config_manager.get_active_group()

    if active_group is None:
        print("⚠️ 没有活动的配置组,使用默认设置")
        print("提示: 可以运行 python gui_config.py 设置配置")
        cam_index = 0
        width = 1920
        height = 1080
    else:
        print(f"📋 使用配置组: {active_group['name']}")
        cam_index = active_group['camera_index']
        width = active_group['camera_width']
        height = active_group['camera_height']

    print(f" 采集卡索引: {cam_index}")
    print(f" 分辨率: {width}x{height}")
    print()

    # Initialise the capture card.
    print("🔧 正在初始化采集卡...")
    get_image = GetImage(
        cam_index=cam_index,
        width=width,
        height=height,
    )

    if get_image.cap is None:
        print("❌ 采集卡初始化失败")
        print("请检查:")
        print("1. 采集卡是否正确连接")
        print("2. 采集卡索引是否正确")
        print("3. 采集卡驱动是否安装")
        return

    # Optional hardware tuning intended to improve image clarity.
    print("\n🔧 设置采集卡参数以提高清晰度...")
    print("提示: 可以根据实际情况调整这些参数")
    set_camera_properties(
        get_image.cap,
        brightness=50,   # brightness (0-100)
        contrast=50,     # contrast (0-100)
        saturation=55,   # saturation (0-100)
        sharpness=60,    # sharpness (0-100)
        gain=None,       # gain: left at the device setting
        exposure=None,   # exposure: left at the device setting
    )

    print("✅ 采集卡初始化成功")
    print("\n快捷键:")
    print(" 'q' 或 ESC - 退出")
    print(" 'e' - 切换图像增强")
    print(" '1'/'2' - 调整锐化强度 (+/-0.1)")
    print(" '3'/'4' - 调整对比度 (+/-0.1)")
    print()

    try:
        frame_count = 0
        enhance_enabled = False  # image enhancement is off by default

        # Image enhancement parameters passed to apply_enhancements.
        enhance_params = {
            'sharpness': True,
            'contrast': True,
            'denoise': True,
            'sharp_strength': 1.5,
            'contrast_alpha': 1.2,
            'contrast_beta': 10,
            'denoise_method': 'bilateral',
        }

        while True:
            frame_data = get_image.get_frame()

            if frame_data is None:
                print("⚠️ 无法获取帧,跳过...")
                continue

            # frame_data is [im_opencv_rgb, im_PIL]; only the PIL image is used.
            _, im_PIL = frame_data

            if im_PIL is None:
                print("⚠️ PIL图像为空,跳过...")
                continue

            # Fresh result dict for every frame.
            detections = {
                'center': None, 'next': None,
                'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None,
                'boss': None, 'zhaozi': None,
                'daojv': [], 'gw': [],
            }

            detections = yolo_shibie(im_PIL, detections, model, enhance_enabled, enhance_params)

            # Hotkeys. Adjusted values are rounded to one decimal so repeated
            # 0.1 steps do not accumulate floating-point drift.
            key = cv2.waitKey(1) & 0xFF
            if key in [27, ord('q'), ord('Q')]:
                print("\n用户退出")
                break
            elif key == ord('e') or key == ord('E'):
                enhance_enabled = not enhance_enabled
                status = "开启" if enhance_enabled else "关闭"
                print(f"图像增强: {status} (锐化={enhance_params['sharp_strength']:.1f}, "
                      f"对比度={enhance_params['contrast_alpha']:.1f})")
            elif key == ord('1'):
                enhance_params['sharp_strength'] = round(
                    min(sharp_max, enhance_params['sharp_strength'] + 0.1), 1)
                print(f"锐化强度: {enhance_params['sharp_strength']:.1f}")
            elif key == ord('2'):
                enhance_params['sharp_strength'] = round(
                    max(sharp_min, enhance_params['sharp_strength'] - 0.1), 1)
                print(f"锐化强度: {enhance_params['sharp_strength']:.1f}")
            elif key == ord('3'):
                enhance_params['contrast_alpha'] = round(
                    min(contrast_max, enhance_params['contrast_alpha'] + 0.1), 1)
                print(f"对比度: {enhance_params['contrast_alpha']:.1f}")
            elif key == ord('4'):
                enhance_params['contrast_alpha'] = round(
                    max(contrast_min, enhance_params['contrast_alpha'] - 0.1), 1)
                print(f"对比度: {enhance_params['contrast_alpha']:.1f}")

            frame_count += 1
            if frame_count % 30 == 0:  # report every 30 frames
                print(f"📊 已处理 {frame_count} 帧")
                # Only list labels that actually have a detection.
                detected_items = {k: v for k, v in detections.items() if v is not None and v != []}
                if detected_items:
                    print(f" 检测到: {detected_items}")

    except KeyboardInterrupt:
        print("\n\n用户中断测试")
    except Exception as e:
        print(f"\n❌ 测试过程中发生错误: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # Always release the capture device and close the preview window.
        get_image.release()
        cv2.destroyAllWindows()
        print("🔚 测试结束")


# Run the live detection test only when this file is executed as a script.
if __name__ == "__main__":
    main()
|