#!/usr/bin/env python3
"""
Complete evaluation pipeline for SAM2 point prompting (TaskRunner-driven version).
"""
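# Example invocations (illustrative; the script filename and the TOML path are
# placeholders, adjust them to your checkout):
#
#   python point_prompt_eval.py --data_root ./crack500 --point_configs 1 3 5 --per_component
#   python point_prompt_eval.py --task_file tasks/point_prompt.toml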
import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import pandas as pd

from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner
@dataclass
class PointCLIArgs:
    data_root: str
    test_file: str
    model_id: str
    point_configs: List[int]
    per_component: bool
    num_vis: int
    skip_inference: bool
    skip_evaluation: bool
    skip_visualization: bool
    skip_comparison: bool
    comparison_dir: str
    config_name: str
    task_file: Optional[str]
def parse_args() -> PointCLIArgs:
    parser = argparse.ArgumentParser(description="SAM2 point prompting - TaskRunner-driven comparison across point counts")
    parser.add_argument("--data_root", type=str, default="./crack500", help="Dataset root directory")
    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="Path to the test split list")
    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 model ID")
    parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="Point counts to test")
    parser.add_argument("--per_component", action="store_true", help="Sample points independently for each connected component")
    parser.add_argument("--num_vis", type=int, default=10, help="Number of samples to visualize")
    parser.add_argument("--skip_inference", action="store_true", help="Skip the inference step")
    parser.add_argument("--skip_evaluation", action="store_true", help="Skip the evaluation step")
    parser.add_argument("--skip_visualization", action="store_true", help="Skip the visualization step")
    parser.add_argument("--skip_comparison", action="store_true", help="Skip the cross-experiment comparison")
    parser.add_argument("--comparison_dir", type=str, default="./results", help="Output directory for comparison results")
    parser.add_argument(
        "--config_name",
        type=str,
        default="sam2_bbox_prompt",
        help="ProjectConfig name (from the ConfigRegistry)",
    )
    parser.add_argument(
        "--task_file",
        type=str,
        default=None,
        help="Optional: path to a TOML task config (if provided, CLI-based task assembly is skipped)",
    )
    args = parser.parse_args()
    return PointCLIArgs(
        data_root=args.data_root,
        test_file=args.test_file,
        model_id=args.model_id,
        point_configs=args.point_configs,
        per_component=args.per_component,
        num_vis=args.num_vis,
        skip_inference=args.skip_inference,
        skip_evaluation=args.skip_evaluation,
        skip_visualization=args.skip_visualization,
        skip_comparison=args.skip_comparison,
        comparison_dir=args.comparison_dir,
        config_name=args.config_name,
        task_file=args.task_file,
    )
def default_output_dir(num_points: int, per_component: bool) -> str:
    if per_component:
        return f"./results/point_prompt_{num_points}pts_per_comp_hf"
    return f"./results/point_prompt_{num_points}pts_hf"
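# Each point-count configuration writes into its own output directory; the
# evaluation and visualization steps assembled in build_task_for_points read
# <output_dir>/predictions and <output_dir>/evaluation_results.csv back from it.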
def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
    steps: List[TaskStepConfig] = []
    # Parameters shared by every step in the pipeline.
    common = {
        "data_root": args.data_root,
        "test_file": args.test_file,
        "model_id": args.model_id,
        "output_dir": output_dir,
    }
    if not args.skip_inference:
        # Step 1: SAM2 inference with point prompts.
        steps.append(
            TaskStepConfig(
                kind="point_inference",
                params={
                    **common,
                    "num_points": num_points,
                    "per_component": args.per_component,
                },
            )
        )
    if not args.skip_evaluation:
        # Step 2: evaluate the saved predictions against the ground truth.
        steps.append(
            TaskStepConfig(
                kind="legacy_evaluation",
                params={
                    **common,
                    "pred_dir": f"{output_dir}/predictions",
                    "compute_skeleton": True,
                },
            )
        )
    if not args.skip_visualization:
        # Step 3: render qualitative samples and a per-run metrics plot.
        steps.append(
            TaskStepConfig(
                kind="legacy_visualization",
                params={
                    **common,
                    "pred_dir": f"{output_dir}/predictions",
                    "results_csv": f"{output_dir}/evaluation_results.csv",
                    "num_samples": args.num_vis,
                    "save_all": False,
                    "create_metrics_plot": True,
                },
            )
        )
    return TaskConfig(
        name=f"point_cli_{num_points}",
        description=f"Legacy point prompt pipeline ({num_points} pts)",
        project_config_name=args.config_name,
        steps=steps,
    )
def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
    csv_path = Path(output_dir) / "evaluation_results.csv"
    if not csv_path.exists():
        return None
    return pd.read_csv(csv_path)
def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
    if not results:
        return
    os.makedirs(output_dir, exist_ok=True)
    # Collapse each results table into one row of means/stds per point count.
    summary_rows = []
    for num_points, df in results.items():
        summary_rows.append(
            {
                "num_points": num_points,
                "iou_mean": df["iou"].mean(),
                "iou_std": df["iou"].std(),
                "dice_mean": df["dice"].mean(),
                "dice_std": df["dice"].std(),
                "f1_mean": df["f1_score"].mean(),
                "f1_std": df["f1_score"].std(),
                "precision_mean": df["precision"].mean(),
                "recall_mean": df["recall"].mean(),
            }
        )
    df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
    summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    df_summary.to_csv(summary_path, index=False)

    # Deferred import: matplotlib is only needed when the comparison plot is generated.
    import matplotlib.pyplot as plt

    metrics_to_plot = [
        ("iou_mean", "iou_std", "IoU"),
        ("dice_mean", "dice_std", "Dice"),
        ("f1_mean", "f1_std", "F1-Score"),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    xs = df_summary["num_points"].tolist()
    for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
        ax.errorbar(
            xs,
            df_summary[mean_col],
            yerr=df_summary[std_col],
            marker="o",
            capsize=5,
            linewidth=2,
            markersize=8,
        )
        ax.set_xlabel("Number of Points", fontsize=12)
        ax.set_ylabel(title, fontsize=12)
        ax.set_title(f"{title} vs Number of Points", fontsize=14)
        ax.grid(True, alpha=0.3)
        ax.set_xticks(xs)
    plt.tight_layout()
    plot_path = summary_path.with_name("performance_comparison.png")
    fig.savefig(plot_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
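# --task_file hands a TOML description of the task straight to TaskRunner and
# bypasses the CLI assembly above. The actual schema is whatever
# load_task_from_toml expects; the sketch below is only an illustration based on
# the TaskConfig / TaskStepConfig fields used in this script and may not match
# the real format:
#
#   name = "point_cli_3"
#   description = "Legacy point prompt pipeline (3 pts)"
#   project_config_name = "sam2_bbox_prompt"
#
#   [[steps]]
#   kind = "point_inference"
#     [steps.params]
#     data_root = "./crack500"
#     test_file = "./crack500/test.txt"
#     model_id = "facebook/sam2-hiera-small"
#     output_dir = "./results/point_prompt_3pts_hf"
#     num_points = 3
#     per_component = false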
def main() -> None:
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    if args.task_file:
        # A TOML task file replaces the CLI-assembled pipeline entirely.
        task = load_task_from_toml(args.task_file)
        TaskRunner(task).run()
        return

    comparison_data: Dict[int, pd.DataFrame] = {}
    for num_points in args.point_configs:
        output_dir = default_output_dir(num_points, args.per_component)
        task = build_task_for_points(args, num_points, output_dir)
        if not task.steps:
            continue
        TaskRunner(task).run()
        # Collect this run's evaluation results for the cross-run comparison.
        if not args.skip_comparison and not args.skip_evaluation:
            df = load_results_csv(output_dir)
            if df is not None:
                comparison_data[num_points] = df
    if not args.skip_comparison and comparison_data:
        compare_results(comparison_data, args.comparison_dir)


if __name__ == "__main__":
    main()