sam_crack/run_point_evaluation.py
#!/usr/bin/env python3
"""
Complete evaluation pipeline for SAM2 point prompting (TaskRunner-driven version).
"""
import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd

from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner
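

# PointCLIArgs is a typed container that mirrors the command-line flags defined in
# parse_args(), so downstream helpers receive explicit fields instead of a raw
# argparse.Namespace.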
@dataclass
class PointCLIArgs:
    data_root: str
    test_file: str
    model_id: str
    point_configs: List[int]
    per_component: bool
    num_vis: int
    skip_inference: bool
    skip_evaluation: bool
    skip_visualization: bool
    skip_comparison: bool
    comparison_dir: str
    config_name: str
    task_file: Optional[str]
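

# parse_args() defines the experiment CLI; every flag maps one-to-one onto a
# PointCLIArgs field above.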
def parse_args() -> PointCLIArgs:
    parser = argparse.ArgumentParser(description="SAM2 point prompting - TaskRunner-driven comparison across point counts")
    parser.add_argument("--data_root", type=str, default="./crack500", help="Dataset root directory")
    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="Path to the test split file")
    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 model ID")
    parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="Point-count configurations to test")
    parser.add_argument("--per_component", action="store_true", help="Sample points independently for each connected component")
    parser.add_argument("--num_vis", type=int, default=10, help="Number of samples to visualize")
    parser.add_argument("--skip_inference", action="store_true", help="Skip the inference step")
    parser.add_argument("--skip_evaluation", action="store_true", help="Skip the evaluation step")
    parser.add_argument("--skip_visualization", action="store_true", help="Skip the visualization step")
    parser.add_argument("--skip_comparison", action="store_true", help="Skip the comparison of experiment results")
    parser.add_argument("--comparison_dir", type=str, default="./results", help="Output directory for comparison results")
    parser.add_argument(
        "--config_name",
        type=str,
        default="sam2_bbox_prompt",
        help="ProjectConfig name (from ConfigRegistry)",
    )
    parser.add_argument(
        "--task_file",
        type=str,
        default=None,
        help="Optional: path to a TOML task config (if provided, CLI-based task assembly is skipped)",
    )
    args = parser.parse_args()
    return PointCLIArgs(
        data_root=args.data_root,
        test_file=args.test_file,
        model_id=args.model_id,
        point_configs=args.point_configs,
        per_component=args.per_component,
        num_vis=args.num_vis,
        skip_inference=args.skip_inference,
        skip_evaluation=args.skip_evaluation,
        skip_visualization=args.skip_visualization,
        skip_comparison=args.skip_comparison,
        comparison_dir=args.comparison_dir,
        config_name=args.config_name,
        task_file=args.task_file,
    )
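

# default_output_dir() derives one results directory per experiment from the point
# count and the per-component flag, e.g. ./results/point_prompt_3pts_hf or
# ./results/point_prompt_3pts_per_comp_hf.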
def default_output_dir(num_points: int, per_component: bool) -> str:
    if per_component:
        return f"./results/point_prompt_{num_points}pts_per_comp_hf"
    return f"./results/point_prompt_{num_points}pts_hf"
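

# build_task_for_points() assembles a TaskConfig with up to three steps, gated by the
# skip flags: "point_inference" (prompting with num_points points per image),
# "legacy_evaluation" (reads {output_dir}/predictions and writes
# {output_dir}/evaluation_results.csv), and "legacy_visualization" (renders up to
# num_vis sample overlays plus a metrics plot). All steps share the common
# data/model/output parameters.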
def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
    steps: List[TaskStepConfig] = []
    common = {
        "data_root": args.data_root,
        "test_file": args.test_file,
        "model_id": args.model_id,
        "output_dir": output_dir,
    }
    if not args.skip_inference:
        steps.append(
            TaskStepConfig(
                kind="point_inference",
                params={
                    **common,
                    "num_points": num_points,
                    "per_component": args.per_component,
                },
            )
        )
    if not args.skip_evaluation:
        steps.append(
            TaskStepConfig(
                kind="legacy_evaluation",
                params={
                    **common,
                    "pred_dir": f"{output_dir}/predictions",
                    "compute_skeleton": True,
                },
            )
        )
    if not args.skip_visualization:
        steps.append(
            TaskStepConfig(
                kind="legacy_visualization",
                params={
                    **common,
                    "pred_dir": f"{output_dir}/predictions",
                    "results_csv": f"{output_dir}/evaluation_results.csv",
                    "num_samples": args.num_vis,
                    "save_all": False,
                    "create_metrics_plot": True,
                },
            )
        )
    return TaskConfig(
        name=f"point_cli_{num_points}",
        description=f"Legacy point prompt pipeline ({num_points} pts)",
        project_config_name=args.config_name,
        steps=steps,
    )
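

# load_results_csv() picks up the per-image metrics written by the evaluation step;
# it returns None when the CSV is missing (e.g. the step was skipped or failed).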
def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
    csv_path = Path(output_dir) / "evaluation_results.csv"
    if not csv_path.exists():
        return None
    return pd.read_csv(csv_path)
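

# compare_results() aggregates the per-image metrics of each point count into mean/std
# summaries (IoU, Dice, F1, precision, recall), saves them to
# {output_dir}/point_comparison/comparison_summary.csv, and plots the three headline
# metrics against the number of points with error bars.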
def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
    if not results:
        return
    os.makedirs(output_dir, exist_ok=True)
    summary_rows = []
    for num_points, df in results.items():
        summary_rows.append(
            {
                "num_points": num_points,
                "iou_mean": df["iou"].mean(),
                "iou_std": df["iou"].std(),
                "dice_mean": df["dice"].mean(),
                "dice_std": df["dice"].std(),
                "f1_mean": df["f1_score"].mean(),
                "f1_std": df["f1_score"].std(),
                "precision_mean": df["precision"].mean(),
                "recall_mean": df["recall"].mean(),
            }
        )
    df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
    summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    df_summary.to_csv(summary_path, index=False)

    import matplotlib.pyplot as plt

    metrics_to_plot = [
        ("iou_mean", "iou_std", "IoU"),
        ("dice_mean", "dice_std", "Dice"),
        ("f1_mean", "f1_std", "F1-Score"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    xs = df_summary["num_points"].tolist()
    for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
        ax.errorbar(
            xs,
            df_summary[mean_col],
            yerr=df_summary[std_col],
            marker="o",
            capsize=5,
            linewidth=2,
            markersize=8,
        )
        ax.set_xlabel("Number of Points", fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(f"{title} vs Number of Points", fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(xs)
    plt.tight_layout()
    plot_path = summary_path.with_name("performance_comparison.png")
    fig.savefig(plot_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
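

# main() runs in one of two modes: if --task_file is given, the TOML task is loaded and
# executed as-is; otherwise a task is built and run for every entry in --point_configs,
# and the per-experiment results are optionally compared at the end.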
def main() -> None:
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    if args.task_file:
        task = load_task_from_toml(args.task_file)
        TaskRunner(task).run()
        return

    comparison_data: Dict[int, pd.DataFrame] = {}
    for num_points in args.point_configs:
        output_dir = default_output_dir(num_points, args.per_component)
        task = build_task_for_points(args, num_points, output_dir)
        if not task.steps:
            continue
        TaskRunner(task).run()
        if not args.skip_comparison and not args.skip_evaluation:
            df = load_results_csv(output_dir)
            if df is not None:
                comparison_data[num_points] = df
    if not args.skip_comparison and comparison_data:
        compare_results(comparison_data, args.comparison_dir)


if __name__ == "__main__":
    main()
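
# Illustrative invocations (a sketch only; defaults such as ./crack500 and
# facebook/sam2-hiera-small come from the flags defined above and may need to be
# adapted to your environment):
#
#   python run_point_evaluation.py --point_configs 1 3 5 --per_component
#   python run_point_evaluation.py --skip_inference --skip_visualization
#   python run_point_evaluation.py --task_file path/to/task.toml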