测试文件提交
This commit is contained in:
56
yolo_test.py
56
yolo_test.py
@@ -161,6 +161,9 @@ def yolo_shibie(im_PIL, detections, model, enhance_enabled=False, enhance_params
|
|||||||
|
|
||||||
# ✅ 提取检测信息
|
# ✅ 提取检测信息
|
||||||
if result.boxes is not None and len(result.boxes.xyxy) > 0:
|
if result.boxes is not None and len(result.boxes.xyxy) > 0:
|
||||||
|
# 用于存储多个候选npc4(如果检测到多个)
|
||||||
|
npc4_candidates = []
|
||||||
|
|
||||||
for i in range(len(result.boxes.xyxy)):
|
for i in range(len(result.boxes.xyxy)):
|
||||||
try:
|
try:
|
||||||
left = float(result.boxes.xyxy[i][0])
|
left = float(result.boxes.xyxy[i][0])
|
||||||
@@ -169,11 +172,37 @@ def yolo_shibie(im_PIL, detections, model, enhance_enabled=False, enhance_params
|
|||||||
bottom = float(result.boxes.xyxy[i][3])
|
bottom = float(result.boxes.xyxy[i][3])
|
||||||
cls_id = int(result.boxes.cls[i])
|
cls_id = int(result.boxes.cls[i])
|
||||||
label = result.names[cls_id]
|
label = result.names[cls_id]
|
||||||
|
|
||||||
|
# 获取置信度(如果可用)
|
||||||
|
confidence = float(result.boxes.conf[i]) if hasattr(result.boxes, 'conf') and len(result.boxes.conf) > i else 1.0
|
||||||
|
|
||||||
if label in ['center', 'next', 'npc1', 'npc2', 'npc3', 'npc4', 'boss', 'zhaozi']:
|
# npc1-npc4 使用底部位置(与main.py保持一致)
|
||||||
|
if label in ['npc1', 'npc2', 'npc3', 'npc4']:
|
||||||
|
player_x = int(left + (right - left) / 2)
|
||||||
|
player_y = int(bottom) + 30 # 使用底部位置,与main.py保持一致
|
||||||
|
position = [player_x, player_y]
|
||||||
|
|
||||||
|
# 特殊处理npc4:如果检测到多个,收集所有候选
|
||||||
|
if label == 'npc4':
|
||||||
|
npc4_candidates.append({
|
||||||
|
'position': position,
|
||||||
|
'confidence': confidence,
|
||||||
|
'box': [left, top, right, bottom],
|
||||||
|
'area': (right - left) * (bottom - top) # 检测框面积
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
# npc1-npc3直接赋值(如果已经有值,保留置信度更高的)
|
||||||
|
if detections[label] is None or (hasattr(result.boxes, 'conf') and
|
||||||
|
confidence > 0.5):
|
||||||
|
detections[label] = position
|
||||||
|
|
||||||
|
# 其他目标使用中心点
|
||||||
|
elif label in ['center', 'next', 'boss', 'zhaozi']:
|
||||||
player_x = int(left + (right - left) / 2) + 3
|
player_x = int(left + (right - left) / 2) + 3
|
||||||
player_y = int(top + (bottom - top) / 2) + 40
|
player_y = int(top + (bottom - top) / 2) + 40
|
||||||
detections[label] = [player_x, player_y]
|
detections[label] = [player_x, player_y]
|
||||||
|
|
||||||
|
# 道具和怪物可以多个
|
||||||
elif label in ['daojv', 'gw']:
|
elif label in ['daojv', 'gw']:
|
||||||
player_x = int(left + (right - left) / 2) + 3
|
player_x = int(left + (right - left) / 2) + 3
|
||||||
player_y = int(top + (bottom - top) / 2) + 40
|
player_y = int(top + (bottom - top) / 2) + 40
|
||||||
@@ -181,9 +210,34 @@ def yolo_shibie(im_PIL, detections, model, enhance_enabled=False, enhance_params
|
|||||||
if label not in detections:
|
if label not in detections:
|
||||||
detections[label] = []
|
detections[label] = []
|
||||||
detections[label].append([player_x, player_y])
|
detections[label].append([player_x, player_y])
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"⚠️ 处理检测框时出错: {e}")
|
print(f"⚠️ 处理检测框时出错: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# 处理npc4:如果检测到多个,选择最合适的
|
||||||
|
if npc4_candidates:
|
||||||
|
# 按置信度排序,选择置信度最高的
|
||||||
|
npc4_candidates.sort(key=lambda x: x['confidence'], reverse=True)
|
||||||
|
|
||||||
|
# 选择最佳候选(置信度最高且面积合理)
|
||||||
|
best_npc4 = None
|
||||||
|
for candidate in npc4_candidates:
|
||||||
|
# 置信度阈值:至少0.3(可根据实际情况调整)
|
||||||
|
if candidate['confidence'] >= 0.3:
|
||||||
|
# 检查检测框面积是否合理(避免过小的误检)
|
||||||
|
area = candidate['area']
|
||||||
|
if area > 100: # 最小面积阈值
|
||||||
|
best_npc4 = candidate
|
||||||
|
break
|
||||||
|
|
||||||
|
if best_npc4:
|
||||||
|
detections['npc4'] = best_npc4['position']
|
||||||
|
# 可选:输出调试信息
|
||||||
|
# print(f"✅ 检测到npc4: 位置={best_npc4['position']}, 置信度={best_npc4['confidence']:.2f}")
|
||||||
|
elif len(npc4_candidates) == 1:
|
||||||
|
# 如果只有一个候选,即使置信度较低也使用
|
||||||
|
detections['npc4'] = npc4_candidates[0]['position']
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"⚠️ YOLO检测出错: {e}")
|
print(f"⚠️ YOLO检测出错: {e}")
|
||||||
|
|||||||
226
yolotest2.py
Normal file
226
yolotest2.py
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
"""
|
||||||
|
从main.py提取的YOLO识别测试文件
|
||||||
|
使用与main.py相同的识别逻辑
|
||||||
|
"""
|
||||||
|
import cv2
|
||||||
|
from utils.get_image import GetImage
|
||||||
|
from ultralytics import YOLO
|
||||||
|
from config import config_manager
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 检查模型文件是否存在
|
||||||
|
model_path = r"best.pt"
|
||||||
|
model0_path = r"best0.pt"
|
||||||
|
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
print(f"❌ 模型文件不存在: {model_path}")
|
||||||
|
exit(1)
|
||||||
|
if not os.path.exists(model0_path):
|
||||||
|
print(f"❌ 模型文件不存在: {model0_path}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
# 加载YOLO模型(与main.py保持一致)
|
||||||
|
try:
|
||||||
|
model = YOLO(model_path).to('cuda')
|
||||||
|
model0 = YOLO(model0_path).to('cuda')
|
||||||
|
print(f"✅ 模型加载成功: {model_path}")
|
||||||
|
print(f"✅ 模型加载成功: {model0_path}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 模型加载失败: {e}")
|
||||||
|
exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
def yolo_shibie(im_PIL, detections, model):
|
||||||
|
"""
|
||||||
|
YOLO识别函数(与main.py中的实现完全一致)
|
||||||
|
:param im_PIL: PIL图像对象
|
||||||
|
:param detections: 检测结果字典
|
||||||
|
:param model: YOLO模型
|
||||||
|
:return: 更新后的detections字典
|
||||||
|
"""
|
||||||
|
results = model(im_PIL) # 目标检测
|
||||||
|
for result in results:
|
||||||
|
for i in range(len(result.boxes.xyxy)):
|
||||||
|
left, top, right, bottom = result.boxes.xyxy[i]
|
||||||
|
scalar_tensor = result.boxes.cls[i]
|
||||||
|
value = scalar_tensor.item()
|
||||||
|
label = result.names[int(value)]
|
||||||
|
if label == 'center' or label == 'next' or label == 'boss' or label == 'zhaozi':
|
||||||
|
player_x = int(left + (right - left) / 2)
|
||||||
|
player_y = int(top + (bottom - top) / 2) + 30
|
||||||
|
RW = [player_x, player_y]
|
||||||
|
detections[label] = RW
|
||||||
|
elif label == 'daojv' or label == 'gw':
|
||||||
|
player_x = int(left + (right - left) / 2)
|
||||||
|
player_y = int(top + (bottom - top) / 2) + 30
|
||||||
|
RW = [player_x, player_y]
|
||||||
|
detections[label].append(RW)
|
||||||
|
elif label == 'npc1' or label == 'npc2' or label == 'npc3' or label == 'npc4':
|
||||||
|
player_x = int(left + (right - left) / 2)
|
||||||
|
player_y = int(bottom) + 30
|
||||||
|
RW = [player_x, player_y]
|
||||||
|
detections[label] = RW
|
||||||
|
return detections
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
print("="*60)
|
||||||
|
print("YOLO识别测试(main.py逻辑)")
|
||||||
|
print("="*60)
|
||||||
|
|
||||||
|
# 从配置加载采集卡设置
|
||||||
|
active_group = config_manager.get_active_group()
|
||||||
|
|
||||||
|
if active_group is None:
|
||||||
|
print("⚠️ 没有活动的配置组,使用默认设置")
|
||||||
|
print("提示: 可以运行 python gui_config.py 设置配置")
|
||||||
|
cam_index = 0
|
||||||
|
width = 1920
|
||||||
|
height = 1080
|
||||||
|
use_model = model # 默认使用model
|
||||||
|
else:
|
||||||
|
print(f"📋 使用配置组: {active_group['name']}")
|
||||||
|
cam_index = active_group['camera_index']
|
||||||
|
width = active_group['camera_width']
|
||||||
|
height = active_group['camera_height']
|
||||||
|
use_model = model0 # 城镇中使用model0
|
||||||
|
print(f" 使用模型: model0 (best0.pt) - 用于城镇识别")
|
||||||
|
|
||||||
|
print(f" 采集卡索引: {cam_index}")
|
||||||
|
print(f" 分辨率: {width}x{height}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
# 初始化采集卡
|
||||||
|
print("🔧 正在初始化采集卡...")
|
||||||
|
get_image = GetImage(
|
||||||
|
cam_index=cam_index,
|
||||||
|
width=width,
|
||||||
|
height=height
|
||||||
|
)
|
||||||
|
|
||||||
|
if get_image.cap is None:
|
||||||
|
print("❌ 采集卡初始化失败")
|
||||||
|
print("请检查:")
|
||||||
|
print("1. 采集卡是否正确连接")
|
||||||
|
print("2. 采集卡索引是否正确")
|
||||||
|
print("3. 采集卡驱动是否安装")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("✅ 采集卡初始化成功")
|
||||||
|
print("\n快捷键:")
|
||||||
|
print(" 'q' 或 ESC - 退出")
|
||||||
|
print(" 'm' - 切换模型 (model/model0)")
|
||||||
|
print(" 'd' - 显示/隐藏检测信息")
|
||||||
|
print()
|
||||||
|
|
||||||
|
try:
|
||||||
|
frame_count = 0
|
||||||
|
show_detections = True # 是否显示检测信息
|
||||||
|
current_model = use_model # 当前使用的模型
|
||||||
|
current_model_name = "model0" if use_model == model0 else "model"
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# 获取帧
|
||||||
|
frame_data = get_image.get_frame()
|
||||||
|
|
||||||
|
if frame_data is None:
|
||||||
|
print("⚠️ 无法获取帧,跳过...")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# frame_data 是 [im_opencv_rgb, im_PIL] 格式
|
||||||
|
im_opencv_rgb, im_PIL = frame_data
|
||||||
|
|
||||||
|
if im_PIL is None:
|
||||||
|
print("⚠️ PIL图像为空,跳过...")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 初始化检测结果字典
|
||||||
|
detections = {
|
||||||
|
'center': None, 'next': None,
|
||||||
|
'npc1': None, 'npc2': None, 'npc3': None, 'npc4': None,
|
||||||
|
'boss': None, 'zhaozi': None,
|
||||||
|
'daojv': [], 'gw': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# 执行YOLO检测(使用main.py的逻辑)
|
||||||
|
detections = yolo_shibie(im_PIL, detections, current_model)
|
||||||
|
|
||||||
|
# 获取绘制好框的图像用于显示
|
||||||
|
try:
|
||||||
|
results = current_model(im_PIL)
|
||||||
|
result = results[0]
|
||||||
|
frame_with_boxes_rgb = result.plot()
|
||||||
|
frame_with_boxes_bgr = cv2.cvtColor(frame_with_boxes_rgb, cv2.COLOR_RGB2BGR)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ 绘制检测框失败: {e}")
|
||||||
|
frame_with_boxes_bgr = cv2.cvtColor(im_opencv_rgb, cv2.COLOR_RGB2BGR)
|
||||||
|
|
||||||
|
# 在图像上显示检测信息
|
||||||
|
if show_detections:
|
||||||
|
# 显示模型名称
|
||||||
|
cv2.putText(frame_with_boxes_bgr, f"Model: {current_model_name}",
|
||||||
|
(10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)
|
||||||
|
|
||||||
|
# 显示检测到的目标
|
||||||
|
y_offset = 60
|
||||||
|
detected_items = []
|
||||||
|
for key, value in detections.items():
|
||||||
|
if value is not None and value != []:
|
||||||
|
if key in ['daojv', 'gw']:
|
||||||
|
detected_items.append(f"{key}: {len(value)}个")
|
||||||
|
else:
|
||||||
|
detected_items.append(f"{key}: {value}")
|
||||||
|
|
||||||
|
if detected_items:
|
||||||
|
text = f"Detected: {', '.join(detected_items[:5])}" # 最多显示5个
|
||||||
|
if len(detected_items) > 5:
|
||||||
|
text += f" ... (+{len(detected_items)-5})"
|
||||||
|
cv2.putText(frame_with_boxes_bgr, text,
|
||||||
|
(10, y_offset), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 0), 2)
|
||||||
|
|
||||||
|
# 显示图像
|
||||||
|
cv2.imshow("YOLO Detection (main.py logic)", frame_with_boxes_bgr)
|
||||||
|
|
||||||
|
# 检查按键
|
||||||
|
key = cv2.waitKey(1) & 0xFF
|
||||||
|
if key in [27, ord('q'), ord('Q')]:
|
||||||
|
print("\n用户退出")
|
||||||
|
break
|
||||||
|
elif key == ord('m') or key == ord('M'):
|
||||||
|
# 切换模型
|
||||||
|
if current_model == model:
|
||||||
|
current_model = model0
|
||||||
|
current_model_name = "model0"
|
||||||
|
else:
|
||||||
|
current_model = model
|
||||||
|
current_model_name = "model"
|
||||||
|
print(f"切换模型: {current_model_name}")
|
||||||
|
elif key == ord('d') or key == ord('D'):
|
||||||
|
show_detections = not show_detections
|
||||||
|
print(f"显示检测信息: {'开启' if show_detections else '关闭'}")
|
||||||
|
|
||||||
|
frame_count += 1
|
||||||
|
if frame_count % 30 == 0: # 每30帧打印一次
|
||||||
|
print(f"📊 已处理 {frame_count} 帧 (模型: {current_model_name})")
|
||||||
|
# 打印有检测到的目标
|
||||||
|
detected_items = {k: v for k, v in detections.items() if v is not None and v != []}
|
||||||
|
if detected_items:
|
||||||
|
print(f" 检测到: {detected_items}")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\n用户中断测试")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n❌ 测试过程中发生错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
finally:
|
||||||
|
# 清理资源
|
||||||
|
get_image.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
print("🔚 测试结束")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
Reference in New Issue
Block a user