#!/usr/bin/env python3 """ SAM2 边界框提示方式完整评估流程 包括:推理 -> 评估 -> 可视化 """ import os import sys import argparse import time from pathlib import Path # 添加 src 目录到路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) from bbox_prompt import process_test_set, build_sam2, SAM2ImagePredictor from evaluation import evaluate_test_set from visualization import visualize_test_set, create_metrics_distribution_plot def parse_args(): """解析命令行参数""" parser = argparse.ArgumentParser( description="SAM2 边界框提示方式 - Crack500 数据集完整评估" ) # 数据集参数 parser.add_argument( "--data_root", type=str, default="./crack500", help="数据集根目录" ) parser.add_argument( "--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径" ) # 模型参数 parser.add_argument( "--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt", help="SAM2 模型检查点路径" ) parser.add_argument( "--model_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_s.yaml", help="SAM2 模型配置文件" ) # 输出参数 parser.add_argument( "--output_dir", type=str, default="./results/bbox_prompt", help="输出目录" ) # 边界框参数 parser.add_argument( "--expand_ratio", type=float, default=0.05, help="边界框扩展比例 (0.0-1.0)" ) # 可视化参数 parser.add_argument( "--num_vis", type=int, default=20, help="可视化样本数量" ) parser.add_argument( "--vis_all", action="store_true", help="可视化所有样本" ) # 流程控制 parser.add_argument( "--skip_inference", action="store_true", help="跳过推理步骤(使用已有预测结果)" ) parser.add_argument( "--skip_evaluation", action="store_true", help="跳过评估步骤" ) parser.add_argument( "--skip_visualization", action="store_true", help="跳过可视化步骤" ) return parser.parse_args() def main(): """主函数""" args = parse_args() print("=" * 80) print("SAM2 边界框提示方式 - Crack500 数据集完整评估") print("=" * 80) print(f"数据集根目录: {args.data_root}") print(f"测试集文件: {args.test_file}") print(f"模型检查点: {args.checkpoint}") print(f"模型配置: {args.model_cfg}") print(f"边界框扩展比例: {args.expand_ratio * 100}%") print(f"输出目录: {args.output_dir}") print("=" * 80) # 检查必要文件 if not os.path.exists(args.data_root): print(f"\n错误: 数据集目录不存在 {args.data_root}") return if not os.path.exists(args.test_file): print(f"\n错误: 测试集文件不存在 {args.test_file}") return # 创建输出目录 os.makedirs(args.output_dir, exist_ok=True) # ========== 步骤 1: 推理 ========== if not args.skip_inference: print("\n" + "=" * 80) print("步骤 1/3: 使用 SAM2 进行推理") print("=" * 80) # 检查模型文件 if not os.path.exists(args.checkpoint): print(f"\n错误: 模型检查点不存在 {args.checkpoint}") print("请先下载 SAM2 模型权重!") print("运行: cd sam2/checkpoints && ./download_ckpts.sh") return try: import torch # 检查 CUDA if not torch.cuda.is_available(): print("警告: CUDA 不可用,将使用 CPU(速度会很慢)") else: print(f"使用 GPU: {torch.cuda.get_device_name(0)}") # 加载模型 print("\n加载 SAM2 模型...") start_time = time.time() sam2_model = build_sam2(args.model_cfg, args.checkpoint) predictor = SAM2ImagePredictor(sam2_model) print(f"模型加载完成!耗时: {time.time() - start_time:.2f}s") # 处理测试集 print("\n开始推理...") start_time = time.time() results = process_test_set( data_root=args.data_root, test_file=args.test_file, predictor=predictor, output_dir=args.output_dir, expand_ratio=args.expand_ratio ) print(f"推理完成!耗时: {time.time() - start_time:.2f}s") print(f"成功处理 {len(results)} 张图像") except Exception as e: print(f"\n推理过程出错: {str(e)}") import traceback traceback.print_exc() return else: print("\n跳过推理步骤(使用已有预测结果)") # ========== 步骤 2: 评估 ========== if not args.skip_evaluation: print("\n" + "=" * 80) print("步骤 2/3: 评估预测结果") print("=" * 80) pred_dir = os.path.join(args.output_dir, "predictions") if not os.path.exists(pred_dir): print(f"\n错误: 预测目录不存在 {pred_dir}") print("请先运行推理步骤!") return try: start_time = time.time() df_results = evaluate_test_set( data_root=args.data_root, test_file=args.test_file, pred_dir=pred_dir, output_dir=args.output_dir, compute_skeleton=True ) print(f"\n评估完成!耗时: {time.time() - start_time:.2f}s") except Exception as e: print(f"\n评估过程出错: {str(e)}") import traceback traceback.print_exc() return else: print("\n跳过评估步骤") # ========== 步骤 3: 可视化 ========== if not args.skip_visualization: print("\n" + "=" * 80) print("步骤 3/3: 生成可视化结果") print("=" * 80) pred_dir = os.path.join(args.output_dir, "predictions") results_csv = os.path.join(args.output_dir, "evaluation_results.csv") if not os.path.exists(pred_dir): print(f"\n错误: 预测目录不存在 {pred_dir}") return try: start_time = time.time() # 可视化样本 visualize_test_set( data_root=args.data_root, test_file=args.test_file, pred_dir=pred_dir, output_dir=args.output_dir, results_csv=results_csv if os.path.exists(results_csv) else None, num_samples=args.num_vis, save_all=args.vis_all ) # 创建指标分布图 if os.path.exists(results_csv): create_metrics_distribution_plot(results_csv, args.output_dir) print(f"\n可视化完成!耗时: {time.time() - start_time:.2f}s") except Exception as e: print(f"\n可视化过程出错: {str(e)}") import traceback traceback.print_exc() return else: print("\n跳过可视化步骤") # ========== 完成 ========== print("\n" + "=" * 80) print("所有步骤完成!") print("=" * 80) print(f"\n结果保存在: {args.output_dir}") print(f" - 预测掩码: {os.path.join(args.output_dir, 'predictions')}") print(f" - 评估结果: {os.path.join(args.output_dir, 'evaluation_results.csv')}") print(f" - 统计摘要: {os.path.join(args.output_dir, 'evaluation_summary.json')}") print(f" - 可视化图像: {os.path.join(args.output_dir, 'visualizations')}") print("=" * 80) if __name__ == "__main__": main()