#!/usr/bin/env python3
"""
End-to-end evaluation pipeline for SAM2 point prompting.

Runs inference, evaluation and visualization for several point counts
(1, 3 and 5 by default) and then compares the resulting metrics.
"""

import os
import sys
import argparse
import time
import traceback
from pathlib import Path

# Make the modules that live under ./src importable.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from point_prompt import process_test_set, build_sam2, SAM2ImagePredictor
from evaluation import evaluate_test_set
from visualization import visualize_test_set, create_metrics_distribution_plot


def _experiment_output_dir(num_points: int, per_component: bool) -> str:
    """Return the output directory used for one experiment configuration.

    Single source of truth for the naming scheme, shared by the experiment
    runner and the final summary so the two cannot drift apart.
    """
    suffix = "_per_comp" if per_component else ""
    return f"./results/point_prompt_{num_points}pts{suffix}"


def run_single_experiment(
    data_root: str,
    test_file: str,
    checkpoint: str,
    model_cfg: str,
    num_points: int,
    per_component: bool = False
):
    """Run inference, evaluation and visualization for one point count.

    Args:
        data_root: Dataset root directory, forwarded to the pipeline stages.
        test_file: Path of the test-set listing file.
        checkpoint: Path of the SAM2 checkpoint passed to ``build_sam2``.
        model_cfg: SAM2 model config name passed to ``build_sam2``.
        num_points: Number of prompt points to use.
        per_component: Forwarded to ``process_test_set``; also selects the
            ``_per_comp`` output directory suffix.

    Returns:
        Whatever ``evaluate_test_set`` returns (used by ``compare_results``
        as a DataFrame of per-image metrics).
    """
    output_dir = _experiment_output_dir(num_points, per_component)
    strategy = '每连通域' if per_component else '全局骨架'

    print("\n" + "=" * 80)
    print(f"实验配置: {num_points} 个点 ({strategy})")
    print("=" * 80)

    # Load the model.
    print("\n加载 SAM2 模型...")
    sam2_model = build_sam2(model_cfg, checkpoint)
    predictor = SAM2ImagePredictor(sam2_model)
    print("模型加载完成!")

    # Step 1: inference. The return value is not needed here; the later
    # stages read from ``<output_dir>/predictions``.
    print(f"\n步骤 1/3: 推理 ({num_points} 个点)")
    start_time = time.time()
    process_test_set(
        data_root=data_root,
        test_file=test_file,
        predictor=predictor,
        output_dir=output_dir,
        num_points=num_points,
        per_component=per_component
    )
    print(f"推理完成!耗时: {time.time() - start_time:.2f}s")

    # Step 2: evaluation.
    print("\n步骤 2/3: 评估")
    pred_dir = os.path.join(output_dir, "predictions")
    start_time = time.time()
    df_results = evaluate_test_set(
        data_root=data_root,
        test_file=test_file,
        pred_dir=pred_dir,
        output_dir=output_dir,
        compute_skeleton=True
    )
    print(f"评估完成!耗时: {time.time() - start_time:.2f}s")

    # Step 3: visualization.
    print("\n步骤 3/3: 可视化")
    results_csv = os.path.join(output_dir, "evaluation_results.csv")
    start_time = time.time()
    visualize_test_set(
        data_root=data_root,
        test_file=test_file,
        pred_dir=pred_dir,
        output_dir=output_dir,
        results_csv=results_csv,
        num_samples=10,
        save_all=False
    )
    create_metrics_distribution_plot(results_csv, output_dir)
    print(f"可视化完成!耗时: {time.time() - start_time:.2f}s")

    return df_results


def compare_results(results_dict: dict, output_dir: str = "./results"):
    """Summarize and plot metrics across the different point counts.

    Args:
        results_dict: Mapping of point count -> per-image metrics DataFrame
            with ``iou``, ``dice``, ``f1_score``, ``precision`` and
            ``recall`` columns.
        output_dir: Root directory; outputs go to
            ``<output_dir>/point_comparison``.

    Returns:
        A DataFrame with one row per point count, ordered by point count.
        An empty DataFrame is returned (and nothing is written) when
        ``results_dict`` is empty.
    """
    # Imported lazily so the heavy plotting stack is only loaded when a
    # comparison is actually requested.
    import pandas as pd
    import matplotlib.pyplot as plt

    print("\n" + "=" * 80)
    print("对比不同点数的性能")
    print("=" * 80)

    if not results_dict:
        # Nothing to compare; avoid writing an empty CSV and then failing
        # on a missing 'num_points' column while plotting.
        print("\n没有可对比的实验结果")
        return pd.DataFrame()

    # Collect one summary row per configuration. Sorting by point count
    # keeps the table ordered and prevents the error-bar line from
    # zig-zagging when the experiments were run out of order.
    summary = []
    for num_points, df in sorted(results_dict.items(), key=lambda kv: kv[0]):
        summary.append({
            'num_points': num_points,
            'iou_mean': df['iou'].mean(),
            'iou_std': df['iou'].std(),
            'dice_mean': df['dice'].mean(),
            'dice_std': df['dice'].std(),
            'f1_mean': df['f1_score'].mean(),
            'f1_std': df['f1_score'].std(),
            'precision_mean': df['precision'].mean(),
            'recall_mean': df['recall'].mean(),
        })

    df_summary = pd.DataFrame(summary)

    # Print the comparison table.
    print("\n性能对比:")
    print(df_summary.to_string(index=False))

    # Save the comparison table.
    comparison_dir = os.path.join(output_dir, "point_comparison")
    os.makedirs(comparison_dir, exist_ok=True)

    csv_path = os.path.join(comparison_dir, "comparison_summary.csv")
    df_summary.to_csv(csv_path, index=False)
    print(f"\n对比结果已保存到: {csv_path}")

    # Plot mean +/- std of each headline metric against the point count.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    metrics_to_plot = [
        ('iou_mean', 'iou_std', 'IoU'),
        ('dice_mean', 'dice_std', 'Dice'),
        ('f1_mean', 'f1_std', 'F1-Score'),
    ]

    x = df_summary['num_points']
    for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
        ax.errorbar(
            x,
            df_summary[mean_col],
            yerr=df_summary[std_col],
            marker='o',
            capsize=5,
            linewidth=2,
            markersize=8,
        )
        ax.set_xlabel('Number of Points', fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(f'{title} vs Number of Points', fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(x)

    plt.tight_layout()

    plot_path = os.path.join(comparison_dir, "performance_comparison.png")
    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"对比图已保存到: {plot_path}")

    return df_summary


def main():
    """Parse CLI arguments, run every configured experiment, then compare.

    A failure in one experiment is reported and does not stop the remaining
    ones. The final summary lists only experiments that actually completed.
    """
    parser = argparse.ArgumentParser(
        description="SAM2 点提示方式 - 多点数对比实验"
    )

    # Dataset arguments.
    parser.add_argument(
        "--data_root", type=str, default="./crack500", help="数据集根目录"
    )
    parser.add_argument(
        "--test_file",
        type=str,
        default="./crack500/test.txt",
        help="测试集文件路径"
    )

    # Model arguments.
    parser.add_argument(
        "--checkpoint",
        type=str,
        default="../sam2/checkpoints/sam2.1_hiera_small.pt",
        help="SAM2 模型检查点路径"
    )
    parser.add_argument(
        "--model_cfg",
        type=str,
        default="sam2.1_hiera_s.yaml",
        help="SAM2 模型配置文件"
    )

    # Experiment arguments.
    parser.add_argument(
        "--point_configs",
        type=int,
        nargs='+',
        default=[1, 3, 5],
        help="要测试的点数配置"
    )
    parser.add_argument(
        "--per_component", action="store_true", help="为每个连通域独立采样"
    )
    parser.add_argument(
        "--skip_comparison", action="store_true", help="跳过对比分析"
    )

    args = parser.parse_args()

    # Drop repeated point counts (keeping first-seen order): a duplicate
    # would only re-run the same experiment into the same directory.
    point_configs = list(dict.fromkeys(args.point_configs))

    print("=" * 80)
    print("SAM2 点提示方式 - 多点数对比实验")
    print("=" * 80)
    print(f"数据集根目录: {args.data_root}")
    print(f"测试集文件: {args.test_file}")
    print(f"模型检查点: {args.checkpoint}")
    print(f"点数配置: {point_configs}")
    print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
    print("=" * 80)

    # Report which device will be used.
    import torch
    if not torch.cuda.is_available():
        print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
    else:
        print(f"使用 GPU: {torch.cuda.get_device_name(0)}")

    # Run every experiment; keep going if one of them fails.
    results_dict = {}
    failed_configs = []

    for num_points in point_configs:
        try:
            results_dict[num_points] = run_single_experiment(
                data_root=args.data_root,
                test_file=args.test_file,
                checkpoint=args.checkpoint,
                model_cfg=args.model_cfg,
                num_points=num_points,
                per_component=args.per_component
            )
        except Exception as e:
            # Top-level boundary: report and move on to the next config.
            print(f"\n实验失败 ({num_points} 个点): {str(e)}")
            traceback.print_exc()
            failed_configs.append(num_points)

    # Comparison needs at least two successful experiments.
    comparison_saved = False
    if not args.skip_comparison and len(results_dict) > 1:
        try:
            compare_results(results_dict)
            comparison_saved = True
        except Exception as e:
            print(f"\n对比分析失败: {str(e)}")
            traceback.print_exc()

    # Final summary.
    print("\n" + "=" * 80)
    print("所有实验完成!")
    print("=" * 80)
    print("\n结果目录:")
    # Only list directories of experiments that actually produced results.
    for num_points in results_dict:
        output_dir = _experiment_output_dir(num_points, args.per_component)
        print(f" - {num_points} 个点: {output_dir}")

    if comparison_saved:
        print(" - 对比分析: ./results/point_comparison")

    if failed_configs:
        print(f"\n失败的实验 (点数): {failed_configs}")
    print("=" * 80)


if __name__ == "__main__":
    main()