#!/usr/bin/env python3
"""
End-to-end evaluation pipeline for SAM2 bounding-box prompting (TaskRunner-driven version).
"""
import argparse
import logging
from dataclasses import dataclass
from typing import List, Optional

from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner


@dataclass
class BBoxCLIArgs:
    """Typed container for the parsed CLI arguments."""

    data_root: str
    test_file: str
    model_id: str
    output_dir: str
    expand_ratio: float
    num_vis: int
    vis_all: bool
    skip_inference: bool
    skip_evaluation: bool
    skip_visualization: bool
    config_name: str
    task_file: Optional[str]


def parse_args() -> BBoxCLIArgs:
    parser = argparse.ArgumentParser(
        description="SAM2 bounding-box prompting - full evaluation driven by TaskRunner"
    )
    parser.add_argument("--data_root", type=str, default="./crack500", help="Dataset root directory")
    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="Path to the test split file")
    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 model ID")
    parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="Output directory")
    parser.add_argument("--expand_ratio", type=float, default=0.05, help="Bounding-box expansion ratio (0.0-1.0)")
    parser.add_argument("--num_vis", type=int, default=20, help="Number of samples to visualize")
    parser.add_argument("--vis_all", action="store_true", help="Visualize all samples")
    parser.add_argument("--skip_inference", action="store_true", help="Skip the inference step")
    parser.add_argument("--skip_evaluation", action="store_true", help="Skip the evaluation step")
    parser.add_argument("--skip_visualization", action="store_true", help="Skip the visualization step")
    parser.add_argument(
        "--config_name",
        type=str,
        default="sam2_bbox_prompt",
        help="ProjectConfig name (from the ConfigRegistry)",
    )
    parser.add_argument(
        "--task_file",
        type=str,
        default=None,
        help="Optional: path to a TOML task config (if given, the other CLI flags are ignored)",
    )
    args = parser.parse_args()
    return BBoxCLIArgs(
        data_root=args.data_root,
        test_file=args.test_file,
        model_id=args.model_id,
        output_dir=args.output_dir,
        expand_ratio=args.expand_ratio,
        num_vis=args.num_vis,
        vis_all=args.vis_all,
        skip_inference=args.skip_inference,
        skip_evaluation=args.skip_evaluation,
        skip_visualization=args.skip_visualization,
        config_name=args.config_name,
        task_file=args.task_file,
    )


def build_cli_task(args: BBoxCLIArgs) -> TaskConfig:
    """Translate CLI flags into a TaskConfig with up to three pipeline steps."""
    steps: List[TaskStepConfig] = []
    # Parameters shared by every step of the pipeline.
    common = {
        "data_root": args.data_root,
        "test_file": args.test_file,
        "model_id": args.model_id,
        "output_dir": args.output_dir,
    }
    if not args.skip_inference:
        steps.append(
            TaskStepConfig(
                kind="bbox_inference",
                params={**common, "expand_ratio": args.expand_ratio},
            )
        )
    if not args.skip_evaluation:
        steps.append(
            TaskStepConfig(
                kind="legacy_evaluation",
                params={
                    **common,
                    "pred_dir": f"{args.output_dir}/predictions",
                    "compute_skeleton": True,
                },
            )
        )
    if not args.skip_visualization:
        steps.append(
            TaskStepConfig(
                kind="legacy_visualization",
                params={
                    **common,
                    "pred_dir": f"{args.output_dir}/predictions",
                    "results_csv": f"{args.output_dir}/evaluation_results.csv",
                    "num_samples": args.num_vis,
                    "save_all": args.vis_all,
                    "create_metrics_plot": True,
                },
            )
        )
    return TaskConfig(
        name="bbox_cli_run",
        description="Legacy bbox prompt pipeline executed via TaskRunner",
        project_config_name=args.config_name,
        steps=steps,
    )


def main() -> None:
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    # A TOML task file, if provided, takes precedence over the CLI-built task.
    if args.task_file:
        task = load_task_from_toml(args.task_file)
    else:
        task = build_cli_task(args)
    if not task.steps:
        raise ValueError("No steps configured for bbox evaluation. Please enable at least one stage.")
    runner = TaskRunner(task)
    runner.run()


if __name__ == "__main__":
    main()
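
# ---------------------------------------------------------------------------
# Example invocations (illustrative only: the script filename below is a
# placeholder and the paths are simply the CLI defaults; adjust to your setup).
#
#   # Full pipeline: bbox inference -> evaluation -> visualization
#   python run_bbox_eval.py --data_root ./crack500 --test_file ./crack500/test.txt \
#       --model_id facebook/sam2-hiera-small --output_dir ./results/bbox_prompt
#
#   # Re-run only evaluation and visualization on existing predictions
#   python run_bbox_eval.py --skip_inference --output_dir ./results/bbox_prompt
#
#   # Drive the run from a TOML task file instead of CLI flags
#   python run_bbox_eval.py --task_file tasks/bbox_eval.toml
#
# Hypothetical TOML layout, assuming load_task_from_toml mirrors the
# TaskConfig / TaskStepConfig fields (name, description, project_config_name,
# steps[].kind, steps[].params); check src/tasks/io.py for the actual schema:
#
#   name = "bbox_toml_run"
#   description = "Bbox prompt pipeline defined in TOML"
#   project_config_name = "sam2_bbox_prompt"
#
#   [[steps]]
#   kind = "bbox_inference"
#   [steps.params]
#   data_root = "./crack500"
#   test_file = "./crack500/test.txt"
#   model_id = "facebook/sam2-hiera-small"
#   output_dir = "./results/bbox_prompt"
#   expand_ratio = 0.05
# ---------------------------------------------------------------------------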