sam_crack/run_bbox_evaluation.py
2025-12-24 13:43:34 +08:00

138 lines
4.5 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
SAM2 边界框提示方式完整评估流程 (TaskRunner 驱动版本)
"""
import argparse
import logging
from dataclasses import dataclass
from typing import List, Optional
from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner
@dataclass
class BBoxCLIArgs:
    """Typed container for every option accepted by this script.

    Instances are produced by ``parse_args`` and consumed by
    ``build_cli_task`` and ``main``.
    """

    # Dataset, model and output locations (shared by every pipeline stage).
    data_root: str
    test_file: str
    model_id: str
    output_dir: str
    # Stage tuning knobs.
    expand_ratio: float  # bbox expansion ratio; help text documents 0.0-1.0
    num_vis: int  # number of samples to visualize
    vis_all: bool  # visualize every sample (forwarded as ``save_all``)
    # Stage toggles: a True value drops the corresponding step.
    skip_inference: bool
    skip_evaluation: bool
    skip_visualization: bool
    # Task selection.
    config_name: str  # ProjectConfig name used for the task
    task_file: Optional[str]  # optional TOML task; when set, main() ignores the other options
def parse_args(argv: Optional[List[str]] = None) -> BBoxCLIArgs:
    """Parse the command line for the bbox-prompt evaluation pipeline.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, so existing callers
            behave as before; an explicit list makes the function usable
            from tests and other scripts.

    Returns:
        A ``BBoxCLIArgs`` holding every parsed option.

    Raises:
        SystemExit: on a parse error, on ``--help``, or when a value is out
            of its documented range (``--expand_ratio`` outside
            ``[0.0, 1.0]`` or a negative ``--num_vis``).
    """
    parser = argparse.ArgumentParser(
        description="SAM2 边界框提示方式 - TaskRunner 驱动完整评估"
    )
    parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
    parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="输出目录")
    parser.add_argument("--expand_ratio", type=float, default=0.05, help="边界框扩展比例 (0.0-1.0)")
    parser.add_argument("--num_vis", type=int, default=20, help="可视化样本数量")
    parser.add_argument("--vis_all", action="store_true", help="可视化所有样本")
    parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
    parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
    parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
    parser.add_argument(
        "--config_name",
        type=str,
        default="sam2_bbox_prompt",
        help="ProjectConfig 名称(来自 ConfigRegistry)",
    )
    parser.add_argument(
        "--task_file",
        type=str,
        default=None,
        help="可选:指向 TOML 任务配置(若提供则忽略其余 CLI 参数)",
    )
    args = parser.parse_args(argv)

    # The help text promises a ratio in [0.0, 1.0]. Reject anything else
    # (including NaN, for which every comparison is False) here, so the
    # inference step never runs with a nonsensical box expansion.
    if not 0.0 <= args.expand_ratio <= 1.0:
        parser.error(
            f"--expand_ratio must be within [0.0, 1.0], got {args.expand_ratio}"
        )
    if args.num_vis < 0:
        parser.error(f"--num_vis must be non-negative, got {args.num_vis}")

    # Each argparse dest matches a BBoxCLIArgs field name one-to-one. A future
    # option without a matching field fails loudly here with a TypeError.
    return BBoxCLIArgs(**vars(args))
def build_cli_task(args: BBoxCLIArgs) -> TaskConfig:
    """Translate parsed CLI options into a ``TaskConfig`` for the TaskRunner.

    Stages are emitted in the fixed order inference -> evaluation ->
    visualization. Each stage is dropped when its ``skip_*`` flag is set.
    Every stage receives the shared dataset/model/output parameters merged
    with its own stage-specific ones.
    """
    out_dir = args.output_dir
    predictions_dir = f"{out_dir}/predictions"
    shared_params = {
        "data_root": args.data_root,
        "test_file": args.test_file,
        "model_id": args.model_id,
        "output_dir": out_dir,
    }

    # (skip flag, step kind, stage-specific params) in execution order.
    stage_plan = [
        (
            args.skip_inference,
            "bbox_inference",
            {"expand_ratio": args.expand_ratio},
        ),
        (
            args.skip_evaluation,
            "legacy_evaluation",
            {"pred_dir": predictions_dir, "compute_skeleton": True},
        ),
        (
            args.skip_visualization,
            "legacy_visualization",
            {
                "pred_dir": predictions_dir,
                "results_csv": f"{out_dir}/evaluation_results.csv",
                "num_samples": args.num_vis,
                "save_all": args.vis_all,
                "create_metrics_plot": True,
            },
        ),
    ]

    # A fresh params dict is built per step; shared keys come first.
    steps: List[TaskStepConfig] = [
        TaskStepConfig(kind=step_kind, params={**shared_params, **extra_params})
        for skipped, step_kind, extra_params in stage_plan
        if not skipped
    ]

    return TaskConfig(
        name="bbox_cli_run",
        description="Legacy bbox prompt pipeline executed via TaskRunner",
        project_config_name=args.config_name,
        steps=steps,
    )
def main() -> None:
    """Command-line entry point: assemble the task and hand it to TaskRunner.

    When ``--task_file`` is given, the task is loaded from that TOML file.
    Otherwise it is built from the remaining CLI options.

    Raises:
        ValueError: if the resulting task has no steps to run.
    """
    logging.basicConfig(level=logging.INFO)
    cli_args = parse_args()

    task = (
        load_task_from_toml(cli_args.task_file)
        if cli_args.task_file
        else build_cli_task(cli_args)
    )

    # Guard clause: refuse to start a runner with nothing to do.
    if not task.steps:
        raise ValueError("No steps configured for bbox evaluation. Please enable at least one stage.")

    TaskRunner(task).run()
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()