#!/usr/bin/env python3
"""Full evaluation pipeline for SAM2 point-prompt mode (TaskRunner-driven version)."""

import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd

from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner


@dataclass
class PointCLIArgs:
    """Parsed command-line arguments for the point-prompt evaluation CLI."""

    data_root: str
    test_file: str
    model_id: str
    point_configs: List[int]
    per_component: bool
    num_vis: int
    skip_inference: bool
    skip_evaluation: bool
    skip_visualization: bool
    skip_comparison: bool
    comparison_dir: str
    config_name: str
    task_file: Optional[str]


def parse_args() -> PointCLIArgs:
    parser = argparse.ArgumentParser(
        description="SAM2 point-prompt mode - TaskRunner-driven comparison across point counts"
    )
    parser.add_argument("--data_root", type=str, default="./crack500", help="Dataset root directory")
    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="Path to the test split file")
    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 model ID")
    parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="Point-count configurations to test")
    parser.add_argument("--per_component", action="store_true", help="Sample points independently for each connected component")
    parser.add_argument("--num_vis", type=int, default=10, help="Number of samples to visualize")
    parser.add_argument("--skip_inference", action="store_true", help="Skip the inference step")
    parser.add_argument("--skip_evaluation", action="store_true", help="Skip the evaluation step")
    parser.add_argument("--skip_visualization", action="store_true", help="Skip the visualization step")
    parser.add_argument("--skip_comparison", action="store_true", help="Skip the cross-experiment comparison")
    parser.add_argument("--comparison_dir", type=str, default="./results", help="Output directory for comparison results")
    parser.add_argument(
        "--config_name",
        type=str,
        default="sam2_bbox_prompt",
        help="ProjectConfig name (from the ConfigRegistry)",
    )
    parser.add_argument(
        "--task_file",
        type=str,
        default=None,
        help="Optional: path to a TOML task config (if provided, the CLI assembly step is skipped)",
    )
    args = parser.parse_args()
    return PointCLIArgs(
        data_root=args.data_root,
        test_file=args.test_file,
        model_id=args.model_id,
        point_configs=args.point_configs,
        per_component=args.per_component,
        num_vis=args.num_vis,
        skip_inference=args.skip_inference,
        skip_evaluation=args.skip_evaluation,
        skip_visualization=args.skip_visualization,
        skip_comparison=args.skip_comparison,
        comparison_dir=args.comparison_dir,
        config_name=args.config_name,
        task_file=args.task_file,
    )


def default_output_dir(num_points: int, per_component: bool) -> str:
    """Return the default results directory for a given point-count configuration."""
    if per_component:
        return f"./results/point_prompt_{num_points}pts_per_comp_hf"
    return f"./results/point_prompt_{num_points}pts_hf"


def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
    """Assemble a TaskConfig (inference -> evaluation -> visualization) for one point count."""
    steps: List[TaskStepConfig] = []
    common = {
        "data_root": args.data_root,
        "test_file": args.test_file,
        "model_id": args.model_id,
        "output_dir": output_dir,
    }

    if not args.skip_inference:
        steps.append(
            TaskStepConfig(
                kind="point_inference",
                params={
                    **common,
                    "num_points": num_points,
                    "per_component": args.per_component,
                },
            )
        )

    if not args.skip_evaluation:
        steps.append(
            TaskStepConfig(
                kind="legacy_evaluation",
                params={
                    **common,
                    "pred_dir": f"{output_dir}/predictions",
                    "compute_skeleton": True,
                },
            )
        )

    if not args.skip_visualization:
        steps.append(
            TaskStepConfig(
                kind="legacy_visualization",
                params={
                    **common,
                    "pred_dir": f"{output_dir}/predictions",
                    "results_csv": f"{output_dir}/evaluation_results.csv",
                    "num_samples": args.num_vis,
                    "save_all": False,
                    "create_metrics_plot": True,
                },
            )
        )

    return TaskConfig(
        name=f"point_cli_{num_points}",
        description=f"Legacy point prompt pipeline ({num_points} pts)",
        project_config_name=args.config_name,
        steps=steps,
    )


def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
    """Load the per-image evaluation CSV produced by the evaluation step, if present."""
    csv_path = Path(output_dir) / "evaluation_results.csv"
    if not csv_path.exists():
        return None
    return pd.read_csv(csv_path)


def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
    """Summarize metrics across point counts and plot mean +/- std curves."""
    if not results:
        return

    os.makedirs(output_dir, exist_ok=True)

    summary_rows = []
    for num_points, df in results.items():
        summary_rows.append(
            {
                "num_points": num_points,
                "iou_mean": df["iou"].mean(),
                "iou_std": df["iou"].std(),
                "dice_mean": df["dice"].mean(),
                "dice_std": df["dice"].std(),
                "f1_mean": df["f1_score"].mean(),
                "f1_std": df["f1_score"].std(),
                "precision_mean": df["precision"].mean(),
                "recall_mean": df["recall"].mean(),
            }
        )

    df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
    summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    df_summary.to_csv(summary_path, index=False)

    # Deferred import so matplotlib is only required when the comparison plot is generated.
    import matplotlib.pyplot as plt

    metrics_to_plot = [
        ("iou_mean", "iou_std", "IoU"),
        ("dice_mean", "dice_std", "Dice"),
        ("f1_mean", "f1_std", "F1-Score"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    xs = df_summary["num_points"].tolist()
    for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
        ax.errorbar(
            xs,
            df_summary[mean_col],
            yerr=df_summary[std_col],
            marker="o",
            capsize=5,
            linewidth=2,
            markersize=8,
        )
        ax.set_xlabel("Number of Points", fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(f"{title} vs Number of Points", fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(xs)

    plt.tight_layout()
    plot_path = summary_path.with_name("performance_comparison.png")
    fig.savefig(plot_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def main() -> None:
    logging.basicConfig(level=logging.INFO)
    args = parse_args()

    # If a TOML task file is given, run it directly and skip the CLI-assembled pipeline.
    if args.task_file:
        task = load_task_from_toml(args.task_file)
        TaskRunner(task).run()
        return

    comparison_data: Dict[int, pd.DataFrame] = {}
    for num_points in args.point_configs:
        output_dir = default_output_dir(num_points, args.per_component)
        task = build_task_for_points(args, num_points, output_dir)
        if not task.steps:
            continue
        TaskRunner(task).run()

        if not args.skip_comparison and not args.skip_evaluation:
            df = load_results_csv(output_dir)
            if df is not None:
                comparison_data[num_points] = df

    if not args.skip_comparison and comparison_data:
        compare_results(comparison_data, args.comparison_dir)


if __name__ == "__main__":
    main()
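
# Example invocations (illustrative sketch only; the script filename used below is an
# assumption and depends on where this file lives in the repository):
#
#   # Compare 1/3/5-point prompts with per-connected-component sampling
#   python run_point_prompt_eval.py --point_configs 1 3 5 --per_component
#
#   # Run an existing TOML task definition instead of assembling steps from the CLI
#   python run_point_prompt_eval.py --task_file tasks/point_eval.toml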