sam_crack/run_bbox_evaluation.py

250 lines
8.0 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
SAM2 边界框提示方式完整评估流程
包括:推理 -> 评估 -> 可视化
"""
import os
import sys
import argparse
import time
from pathlib import Path
# 添加 src 目录到路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from bbox_prompt import process_test_set, build_sam2, SAM2ImagePredictor
from evaluation import evaluate_test_set
from visualization import visualize_test_set, create_metrics_distribution_plot
def parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(
description="SAM2 边界框提示方式 - Crack500 数据集完整评估"
)
# 数据集参数
parser.add_argument(
"--data_root", type=str, default="./crack500",
help="数据集根目录"
)
parser.add_argument(
"--test_file", type=str, default="./crack500/test.txt",
help="测试集文件路径"
)
# 模型参数
parser.add_argument(
"--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
help="SAM2 模型检查点路径"
)
parser.add_argument(
"--model_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_s.yaml",
help="SAM2 模型配置文件"
)
# 输出参数
parser.add_argument(
"--output_dir", type=str, default="./results/bbox_prompt",
help="输出目录"
)
# 边界框参数
parser.add_argument(
"--expand_ratio", type=float, default=0.05,
help="边界框扩展比例 (0.0-1.0)"
)
# 可视化参数
parser.add_argument(
"--num_vis", type=int, default=20,
help="可视化样本数量"
)
parser.add_argument(
"--vis_all", action="store_true",
help="可视化所有样本"
)
# 流程控制
parser.add_argument(
"--skip_inference", action="store_true",
help="跳过推理步骤(使用已有预测结果)"
)
parser.add_argument(
"--skip_evaluation", action="store_true",
help="跳过评估步骤"
)
parser.add_argument(
"--skip_visualization", action="store_true",
help="跳过可视化步骤"
)
return parser.parse_args()
def main():
"""主函数"""
args = parse_args()
print("=" * 80)
print("SAM2 边界框提示方式 - Crack500 数据集完整评估")
print("=" * 80)
print(f"数据集根目录: {args.data_root}")
print(f"测试集文件: {args.test_file}")
print(f"模型检查点: {args.checkpoint}")
print(f"模型配置: {args.model_cfg}")
print(f"边界框扩展比例: {args.expand_ratio * 100}%")
print(f"输出目录: {args.output_dir}")
print("=" * 80)
# 检查必要文件
if not os.path.exists(args.data_root):
print(f"\n错误: 数据集目录不存在 {args.data_root}")
return
if not os.path.exists(args.test_file):
print(f"\n错误: 测试集文件不存在 {args.test_file}")
return
# 创建输出目录
os.makedirs(args.output_dir, exist_ok=True)
# ========== 步骤 1: 推理 ==========
if not args.skip_inference:
print("\n" + "=" * 80)
print("步骤 1/3: 使用 SAM2 进行推理")
print("=" * 80)
# 检查模型文件
if not os.path.exists(args.checkpoint):
print(f"\n错误: 模型检查点不存在 {args.checkpoint}")
print("请先下载 SAM2 模型权重!")
print("运行: cd sam2/checkpoints && ./download_ckpts.sh")
return
try:
import torch
# 检查 CUDA
if not torch.cuda.is_available():
print("警告: CUDA 不可用,将使用 CPU速度会很慢")
else:
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
# 加载模型
print("\n加载 SAM2 模型...")
start_time = time.time()
sam2_model = build_sam2(args.model_cfg, args.checkpoint)
predictor = SAM2ImagePredictor(sam2_model)
print(f"模型加载完成!耗时: {time.time() - start_time:.2f}s")
# 处理测试集
print("\n开始推理...")
start_time = time.time()
results = process_test_set(
data_root=args.data_root,
test_file=args.test_file,
predictor=predictor,
output_dir=args.output_dir,
expand_ratio=args.expand_ratio
)
print(f"推理完成!耗时: {time.time() - start_time:.2f}s")
print(f"成功处理 {len(results)} 张图像")
except Exception as e:
print(f"\n推理过程出错: {str(e)}")
import traceback
traceback.print_exc()
return
else:
print("\n跳过推理步骤(使用已有预测结果)")
# ========== 步骤 2: 评估 ==========
if not args.skip_evaluation:
print("\n" + "=" * 80)
print("步骤 2/3: 评估预测结果")
print("=" * 80)
pred_dir = os.path.join(args.output_dir, "predictions")
if not os.path.exists(pred_dir):
print(f"\n错误: 预测目录不存在 {pred_dir}")
print("请先运行推理步骤!")
return
try:
start_time = time.time()
df_results = evaluate_test_set(
data_root=args.data_root,
test_file=args.test_file,
pred_dir=pred_dir,
output_dir=args.output_dir,
compute_skeleton=True
)
print(f"\n评估完成!耗时: {time.time() - start_time:.2f}s")
except Exception as e:
print(f"\n评估过程出错: {str(e)}")
import traceback
traceback.print_exc()
return
else:
print("\n跳过评估步骤")
# ========== 步骤 3: 可视化 ==========
if not args.skip_visualization:
print("\n" + "=" * 80)
print("步骤 3/3: 生成可视化结果")
print("=" * 80)
pred_dir = os.path.join(args.output_dir, "predictions")
results_csv = os.path.join(args.output_dir, "evaluation_results.csv")
if not os.path.exists(pred_dir):
print(f"\n错误: 预测目录不存在 {pred_dir}")
return
try:
start_time = time.time()
# 可视化样本
visualize_test_set(
data_root=args.data_root,
test_file=args.test_file,
pred_dir=pred_dir,
output_dir=args.output_dir,
results_csv=results_csv if os.path.exists(results_csv) else None,
num_samples=args.num_vis,
save_all=args.vis_all
)
# 创建指标分布图
if os.path.exists(results_csv):
create_metrics_distribution_plot(results_csv, args.output_dir)
print(f"\n可视化完成!耗时: {time.time() - start_time:.2f}s")
except Exception as e:
print(f"\n可视化过程出错: {str(e)}")
import traceback
traceback.print_exc()
return
else:
print("\n跳过可视化步骤")
# ========== 完成 ==========
print("\n" + "=" * 80)
print("所有步骤完成!")
print("=" * 80)
print(f"\n结果保存在: {args.output_dir}")
print(f" - 预测掩码: {os.path.join(args.output_dir, 'predictions')}")
print(f" - 评估结果: {os.path.join(args.output_dir, 'evaluation_results.csv')}")
print(f" - 统计摘要: {os.path.join(args.output_dir, 'evaluation_summary.json')}")
print(f" - 可视化图像: {os.path.join(args.output_dir, 'visualizations')}")
print("=" * 80)
if __name__ == "__main__":
main()