revert to oldest implementation

Dustella 2025-12-24 14:58:07 +08:00
parent 4886fc8861
commit 5371973442
No known key found for this signature in database
GPG Key ID: C6227AE4A45E0187
46 changed files with 815 additions and 2566 deletions

View File

@@ -1,35 +0,0 @@
{
"crop_size": null,
"data_format": "channels_first",
"default_to_square": true,
"device": null,
"disable_grouping": null,
"do_center_crop": null,
"do_convert_rgb": true,
"do_normalize": false,
"do_rescale": false,
"do_resize": false,
"image_mean": [
0.485,
0.456,
0.406
],
"image_processor_type": "Sam2ImageProcessorFast",
"image_std": [
0.229,
0.224,
0.225
],
"input_data_format": null,
"mask_size": {
"height": 256,
"width": 256
},
"processor_class": "Sam2VideoProcessor",
"resample": 2,
"rescale_factor": 0.00392156862745098,
"return_tensors": null,
"size": {
"longest_edge": 1024
}
}
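The file deleted above is a Transformers preprocessor config (a preprocessor_config.json). Loaders in the AutoProcessor family read exactly this file to rebuild the image processor. A minimal sketch of how such a config is typically consumed; the local directory path is hypothetical, and the SAM2 processor classes named in the JSON require a transformers version that ships them:

# Sketch: load a saved preprocessor_config.json back into an image processor.
# "./sam2_processor_dir" is a hypothetical directory containing the JSON above.
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("./sam2_processor_dir")
print(type(processor).__name__)  # should match image_processor_type, e.g. Sam2ImageProcessorFast
print(processor.size)            # {"longest_edge": 1024}, per the config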

pixi.lock (generated): 98 changed lines
View File

@@ -34,6 +34,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_ha0e22de_103.conda
- conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-h8577fbf_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
@@ -48,8 +49,10 @@ environments:
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f1/df/8ee1c5dd1e3308b5d5b2f2dfea323bb2f3827da8d654abb6642051199049/ipython-9.8.0-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl
@@ -81,6 +84,7 @@ environments:
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3b/6c/99acb2f9eb85c29fc6f3a7ac4dccfd992e22666dd08a642b303311326a97/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
@@ -88,6 +92,7 @@ environments:
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4f/87/424511bdcd02c8d7acf9f65caa09f291a519b16bd83c3fb3374b3d4ae951/pillow-12.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4b/a6/38c8e2f318bf67d338f4d629e93b0b4b9af331f455f0390ea8ce4a099b26/portalocker-3.2.0-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/12/ff/e93136587c00a543f4bc768b157fac2c47cd77b180d4f4e5c6efb6ea53a2/psutil-7.2.0-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
@@ -122,6 +127,7 @@ environments:
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
- pypi: /home/dustella/projects/sam2
packages:
- conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
sha256: fe51de6107f9edc7aa4f786a70f4a883943bc9d39b3bb7307c04c41410990726
@@ -144,6 +150,12 @@ packages:
purls: []
size: 23621
timestamp: 1650670423406
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
name: antlr4-python3-runtime
version: 4.9.3
sha256: f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b
requires_dist:
- typing ; python_full_version < '3.5'
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
name: asttokens
version: 3.0.1
@@ -540,6 +552,15 @@ packages:
- types-tqdm ; extra == 'typing'
- types-urllib3 ; extra == 'typing'
requires_python: '>=3.8.0'
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl
name: hydra-core
version: 1.3.2
sha256: fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b
requires_dist:
- omegaconf>=2.2,<2.4
- antlr4-python3-runtime==4.9.*
- packaging
- importlib-resources ; python_full_version < '3.9'
- conda: https://conda.anaconda.org/conda-forge/linux-64/icu-78.1-h33c6efd_0.conda
sha256: 7d6463d0be5092b2ae8f2fad34dc84de83eab8bd44cc0d4be8931881c973c48f
md5: 518e9bbbc3e3486d6a4519192ba690f8
@@ -623,6 +644,17 @@ packages:
- sphinx<6 ; extra == 'full'
- tifffile ; extra == 'full'
requires_python: '>=3.9'
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz
name: iopath
version: 0.1.10
sha256: 3311c16a4d9137223e20f141655759933e1eda24f8bff166af834af3c645ef01
requires_dist:
- tqdm
- typing-extensions
- portalocker
- dataclasses ; python_full_version < '3.7'
- boto3 ; extra == 'aws'
requires_python: '>=3.6'
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
name: ipykernel
version: 7.1.0
@@ -1174,6 +1206,15 @@ packages:
version: 12.8.90
sha256: 5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f
requires_python: '>=3'
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl
name: omegaconf
version: 2.3.0
sha256: 7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b
requires_dist:
- antlr4-python3-runtime==4.9.*
- pyyaml>=5.1.0
- dataclasses ; python_full_version == '3.6.*'
requires_python: '>=3.6'
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
name: opencv-python
version: 4.12.0.88
@@ -1355,6 +1396,25 @@ packages:
- pytest>=8.4.2 ; extra == 'test'
- mypy>=1.18.2 ; extra == 'type'
requires_python: '>=3.10'
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4b/a6/38c8e2f318bf67d338f4d629e93b0b4b9af331f455f0390ea8ce4a099b26/portalocker-3.2.0-py3-none-any.whl
name: portalocker
version: 3.2.0
sha256: 3cdc5f565312224bc570c49337bd21428bba0ef363bbcf58b9ef4a9f11779968
requires_dist:
- pywin32>=226 ; sys_platform == 'win32'
- portalocker[tests] ; extra == 'docs'
- coverage-conditional-plugin>=0.9.0 ; extra == 'tests'
- portalocker[redis] ; extra == 'tests'
- pytest-cov>=2.8.1 ; extra == 'tests'
- pytest-mypy>=0.8.0 ; extra == 'tests'
- pytest-rerunfailures>=15.0 ; extra == 'tests'
- pytest-timeout>=2.1.0 ; extra == 'tests'
- pytest>=5.4.1 ; extra == 'tests'
- sphinx>=6.0.0 ; extra == 'tests'
- types-pywin32>=310.0.0.20250429 ; extra == 'tests'
- types-redis ; extra == 'tests'
- redis ; extra == 'redis'
requires_python: '>=3.9'
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
name: prompt-toolkit
version: 3.0.52
@@ -1544,6 +1604,44 @@ packages:
- safetensors[testing] ; extra == 'all'
- safetensors[all] ; extra == 'dev'
requires_python: '>=3.9'
- pypi: /home/dustella/projects/sam2
name: sam-2
version: '1.0'
sha256: f3cc0abfc266b6f57a457d26bdf52fcaa92567d8fa27ba6ff59ebe60bec0ac69
requires_dist:
- torch>=2.5.1
- torchvision>=0.20.1
- numpy>=1.24.4
- tqdm>=4.66.1
- hydra-core>=1.3.2
- iopath>=0.1.10
- pillow>=9.4.0
- matplotlib>=3.9.1 ; extra == 'notebooks'
- jupyter>=1.0.0 ; extra == 'notebooks'
- opencv-python>=4.7.0 ; extra == 'notebooks'
- eva-decord>=0.6.1 ; extra == 'notebooks'
- flask>=3.0.3 ; extra == 'interactive-demo'
- flask-cors>=5.0.0 ; extra == 'interactive-demo'
- av>=13.0.0 ; extra == 'interactive-demo'
- dataclasses-json>=0.6.7 ; extra == 'interactive-demo'
- eva-decord>=0.6.1 ; extra == 'interactive-demo'
- gunicorn>=23.0.0 ; extra == 'interactive-demo'
- imagesize>=1.4.1 ; extra == 'interactive-demo'
- pycocotools>=2.0.8 ; extra == 'interactive-demo'
- strawberry-graphql>=0.243.0 ; extra == 'interactive-demo'
- black==24.2.0 ; extra == 'dev'
- usort==1.0.2 ; extra == 'dev'
- ufmt==2.0.0b2 ; extra == 'dev'
- fvcore>=0.1.5.post20221221 ; extra == 'dev'
- pandas>=2.2.2 ; extra == 'dev'
- scikit-image>=0.24.0 ; extra == 'dev'
- tensorboard>=2.17.0 ; extra == 'dev'
- pycocotools>=2.0.8 ; extra == 'dev'
- tensordict>=0.6.0 ; extra == 'dev'
- opencv-python>=4.7.0 ; extra == 'dev'
- submitit>=1.5.1 ; extra == 'dev'
requires_python: '>=3.10.0'
editable: true
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f4/a2/70401a107d6d7466d64b466927e6b96fcefa99d57494b972608e2f8be50f/scikit_image-0.26.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
name: scikit-image
version: 0.26.0

View File

@@ -26,3 +26,4 @@ tqdm = ">=4.65.0"
pandas = ">=2.0.0"
transformers = ">=4.57.3, <5"
ipykernel = ">=7.1.0, <8"
sam-2 = { path = "/home/dustella/projects/sam2", editable = true }
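This pixi.toml line wires SAM2 in as an editable path dependency, so edits under /home/dustella/projects/sam2 take effect without reinstalling. A quick sanity check from Python (a sketch; the distribution name sam-2 and version 1.0 come from the lock entry above):

# Sketch: confirm the editable sam-2 install resolves to the local checkout.
import importlib.metadata

print(importlib.metadata.version("sam-2"))  # expected: 1.0, per pixi.lock

import sam2
print(sam2.__file__)  # should point under /home/dustella/projects/sam2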

View File

@@ -1,136 +1,248 @@
#!/usr/bin/env python3
"""
Complete SAM2 evaluation pipeline with bounding-box prompts (TaskRunner-driven version)
Complete SAM2 evaluation pipeline with bounding-box prompts
Covers inference -> evaluation -> visualization
"""
import os
import sys
import argparse
import logging
from dataclasses import dataclass
from typing import List, Optional
import time
from pathlib import Path
from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner
# Add the src directory to the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from bbox_prompt import process_test_set, build_sam2, SAM2ImagePredictor
from evaluation import evaluate_test_set
from visualization import visualize_test_set, create_metrics_distribution_plot
@dataclass
class BBoxCLIArgs:
data_root: str
test_file: str
model_id: str
output_dir: str
expand_ratio: float
num_vis: int
vis_all: bool
skip_inference: bool
skip_evaluation: bool
skip_visualization: bool
config_name: str
task_file: Optional[str]
def parse_args() -> BBoxCLIArgs:
def parse_args():
"""解析命令行参数"""
parser = argparse.ArgumentParser(
description="SAM2 边界框提示方式 - TaskRunner 驱动完整评估"
description="SAM2 边界框提示方式 - Crack500 数据集完整评估"
)
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="输出目录")
parser.add_argument("--expand_ratio", type=float, default=0.05, help="边界框扩展比例 (0.0-1.0)")
parser.add_argument("--num_vis", type=int, default=20, help="可视化样本数量")
parser.add_argument("--vis_all", action="store_true", help="可视化所有样本")
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
# 数据集参数
parser.add_argument(
"--config_name",
type=str,
default="sam2_bbox_prompt",
help="ProjectConfig 名称(来自 ConfigRegistry",
"--data_root", type=str, default="./crack500",
help="数据集根目录"
)
parser.add_argument(
"--task_file",
type=str,
default=None,
help="可选:指向 TOML 任务配置(若提供则忽略其余 CLI 参数)",
)
args = parser.parse_args()
return BBoxCLIArgs(
data_root=args.data_root,
test_file=args.test_file,
model_id=args.model_id,
output_dir=args.output_dir,
expand_ratio=args.expand_ratio,
num_vis=args.num_vis,
vis_all=args.vis_all,
skip_inference=args.skip_inference,
skip_evaluation=args.skip_evaluation,
skip_visualization=args.skip_visualization,
config_name=args.config_name,
task_file=args.task_file,
"--test_file", type=str, default="./crack500/test.txt",
help="测试集文件路径"
)
def build_cli_task(args: BBoxCLIArgs) -> TaskConfig:
steps: List[TaskStepConfig] = []
common = {
"data_root": args.data_root,
"test_file": args.test_file,
"model_id": args.model_id,
"output_dir": args.output_dir,
}
if not args.skip_inference:
steps.append(
TaskStepConfig(
kind="bbox_inference",
params={**common, "expand_ratio": args.expand_ratio},
)
)
if not args.skip_evaluation:
steps.append(
TaskStepConfig(
kind="legacy_evaluation",
params={
**common,
"pred_dir": f"{args.output_dir}/predictions",
"compute_skeleton": True,
},
)
)
if not args.skip_visualization:
steps.append(
TaskStepConfig(
kind="legacy_visualization",
params={
**common,
"pred_dir": f"{args.output_dir}/predictions",
"results_csv": f"{args.output_dir}/evaluation_results.csv",
"num_samples": args.num_vis,
"save_all": args.vis_all,
"create_metrics_plot": True,
},
)
)
return TaskConfig(
name="bbox_cli_run",
description="Legacy bbox prompt pipeline executed via TaskRunner",
project_config_name=args.config_name,
steps=steps,
# Model arguments
parser.add_argument(
"--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
help="Path to the SAM2 model checkpoint"
)
parser.add_argument(
"--model_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_s.yaml",
help="SAM2 model config file"
)
# Output arguments
parser.add_argument(
"--output_dir", type=str, default="./results/bbox_prompt",
help="Output directory"
)
def main() -> None:
logging.basicConfig(level=logging.INFO)
# Bounding-box arguments
parser.add_argument(
"--expand_ratio", type=float, default=0.05,
help="Bounding-box expansion ratio (0.0-1.0)"
)
# Visualization arguments
parser.add_argument(
"--num_vis", type=int, default=20,
help="Number of samples to visualize"
)
parser.add_argument(
"--vis_all", action="store_true",
help="Visualize all samples"
)
# Pipeline control
parser.add_argument(
"--skip_inference", action="store_true",
help="Skip the inference step (use existing predictions)"
)
parser.add_argument(
"--skip_evaluation", action="store_true",
help="Skip the evaluation step"
)
parser.add_argument(
"--skip_visualization", action="store_true",
help="Skip the visualization step"
)
return parser.parse_args()
def main():
"""主函数"""
args = parse_args()
if args.task_file:
task = load_task_from_toml(args.task_file)
print("=" * 80)
print("SAM2 边界框提示方式 - Crack500 数据集完整评估")
print("=" * 80)
print(f"数据集根目录: {args.data_root}")
print(f"测试集文件: {args.test_file}")
print(f"模型检查点: {args.checkpoint}")
print(f"模型配置: {args.model_cfg}")
print(f"边界框扩展比例: {args.expand_ratio * 100}%")
print(f"输出目录: {args.output_dir}")
print("=" * 80)
# Check required files
if not os.path.exists(args.data_root):
print(f"\nError: dataset directory does not exist: {args.data_root}")
return
if not os.path.exists(args.test_file):
print(f"\nError: test split file does not exist: {args.test_file}")
return
# Create the output directory
os.makedirs(args.output_dir, exist_ok=True)
# ========== Step 1: Inference ==========
if not args.skip_inference:
print("\n" + "=" * 80)
print("Step 1/3: Running inference with SAM2")
print("=" * 80)
# Check the model files
if not os.path.exists(args.checkpoint):
print(f"\nError: model checkpoint does not exist: {args.checkpoint}")
print("Please download the SAM2 model weights first!")
print("Run: cd sam2/checkpoints && ./download_ckpts.sh")
return
try:
import torch
# Check CUDA
if not torch.cuda.is_available():
print("Warning: CUDA is not available; falling back to CPU (this will be slow)")
else:
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
# Load the model
print("\nLoading SAM2 model...")
start_time = time.time()
sam2_model = build_sam2(args.model_cfg, args.checkpoint)
predictor = SAM2ImagePredictor(sam2_model)
print(f"Model loaded! Elapsed: {time.time() - start_time:.2f}s")
# Process the test set
print("\nStarting inference...")
start_time = time.time()
results = process_test_set(
data_root=args.data_root,
test_file=args.test_file,
predictor=predictor,
output_dir=args.output_dir,
expand_ratio=args.expand_ratio
)
print(f"Inference finished! Elapsed: {time.time() - start_time:.2f}s")
print(f"Successfully processed {len(results)} images")
except Exception as e:
print(f"\nInference failed: {str(e)}")
import traceback
traceback.print_exc()
return
else:
task = build_cli_task(args)
if not task.steps:
raise ValueError("No steps configured for bbox evaluation. Please enable at least one stage.")
runner = TaskRunner(task)
runner.run()
print("\nSkipping inference step (using existing predictions)")
# ========== Step 2: Evaluation ==========
if not args.skip_evaluation:
print("\n" + "=" * 80)
print("Step 2/3: Evaluating predictions")
print("=" * 80)
pred_dir = os.path.join(args.output_dir, "predictions")
if not os.path.exists(pred_dir):
print(f"\nError: prediction directory does not exist: {pred_dir}")
print("Please run the inference step first!")
return
try:
start_time = time.time()
df_results = evaluate_test_set(
data_root=args.data_root,
test_file=args.test_file,
pred_dir=pred_dir,
output_dir=args.output_dir,
compute_skeleton=True
)
print(f"\n评估完成!耗时: {time.time() - start_time:.2f}s")
except Exception as e:
print(f"\n评估过程出错: {str(e)}")
import traceback
traceback.print_exc()
return
else:
print("\n跳过评估步骤")
# ========== 步骤 3: 可视化 ==========
if not args.skip_visualization:
print("\n" + "=" * 80)
print("步骤 3/3: 生成可视化结果")
print("=" * 80)
pred_dir = os.path.join(args.output_dir, "predictions")
results_csv = os.path.join(args.output_dir, "evaluation_results.csv")
if not os.path.exists(pred_dir):
print(f"\n错误: 预测目录不存在 {pred_dir}")
return
try:
start_time = time.time()
# 可视化样本
visualize_test_set(
data_root=args.data_root,
test_file=args.test_file,
pred_dir=pred_dir,
output_dir=args.output_dir,
results_csv=results_csv if os.path.exists(results_csv) else None,
num_samples=args.num_vis,
save_all=args.vis_all
)
# 创建指标分布图
if os.path.exists(results_csv):
create_metrics_distribution_plot(results_csv, args.output_dir)
print(f"\n可视化完成!耗时: {time.time() - start_time:.2f}s")
except Exception as e:
print(f"\n可视化过程出错: {str(e)}")
import traceback
traceback.print_exc()
return
else:
print("\n跳过可视化步骤")
# ========== 完成 ==========
print("\n" + "=" * 80)
print("所有步骤完成!")
print("=" * 80)
print(f"\n结果保存在: {args.output_dir}")
print(f" - 预测掩码: {os.path.join(args.output_dir, 'predictions')}")
print(f" - 评估结果: {os.path.join(args.output_dir, 'evaluation_results.csv')}")
print(f" - 统计摘要: {os.path.join(args.output_dir, 'evaluation_summary.json')}")
print(f" - 可视化图像: {os.path.join(args.output_dir, 'visualizations')}")
print("=" * 80)
if __name__ == "__main__":

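The restored script wires the three stages together as plain function calls, so they can also be chained from Python without the CLI. A minimal sketch using the functions imported at the top of the file, with the script's default paths and no error handling:

# Sketch: inference -> evaluation -> visualization, mirroring main() above.
sam2_model = build_sam2("configs/sam2.1/sam2.1_hiera_s.yaml",
                        "../sam2/checkpoints/sam2.1_hiera_small.pt")
predictor = SAM2ImagePredictor(sam2_model)

process_test_set(data_root="./crack500", test_file="./crack500/test.txt",
                 predictor=predictor, output_dir="./results/bbox_prompt",
                 expand_ratio=0.05)
evaluate_test_set(data_root="./crack500", test_file="./crack500/test.txt",
                  pred_dir="./results/bbox_prompt/predictions",
                  output_dir="./results/bbox_prompt", compute_skeleton=True)
visualize_test_set(data_root="./crack500", test_file="./crack500/test.txt",
                   pred_dir="./results/bbox_prompt/predictions",
                   output_dir="./results/bbox_prompt",
                   results_csv="./results/bbox_prompt/evaluation_results.csv",
                   num_samples=20, save_all=False)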
View File

@@ -1,222 +1,272 @@
#!/usr/bin/env python3
"""
Complete SAM2 evaluation pipeline with point prompts (TaskRunner-driven version)
Complete SAM2 evaluation pipeline with point prompts
Supports comparison experiments with 1, 3, and 5 points
"""
import argparse
import logging
import os
from dataclasses import dataclass
import sys
import argparse
import time
from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
# Add the src directory to the path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from src.tasks.config import TaskConfig, TaskStepConfig
from src.tasks.io import load_task_from_toml
from src.tasks.pipeline import TaskRunner
from point_prompt import process_test_set, build_sam2, SAM2ImagePredictor
from evaluation import evaluate_test_set
from visualization import visualize_test_set, create_metrics_distribution_plot
@dataclass
class PointCLIArgs:
data_root: str
test_file: str
model_id: str
point_configs: List[int]
per_component: bool
num_vis: int
skip_inference: bool
skip_evaluation: bool
skip_visualization: bool
skip_comparison: bool
comparison_dir: str
config_name: str
task_file: Optional[str]
def run_single_experiment(
data_root: str,
test_file: str,
checkpoint: str,
model_cfg: str,
num_points: int,
per_component: bool = False
):
"""运行单个点数的实验"""
def parse_args() -> PointCLIArgs:
parser = argparse.ArgumentParser(description="SAM2 点提示方式 - TaskRunner 驱动多点数对比实验")
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="要测试的点数配置")
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样点")
parser.add_argument("--num_vis", type=int, default=10, help="可视化样本数量")
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
parser.add_argument("--skip_comparison", action="store_true", help="跳过实验结果对比")
parser.add_argument("--comparison_dir", type=str, default="./results", help="对比结果输出目录")
parser.add_argument(
"--config_name",
type=str,
default="sam2_bbox_prompt",
help="ProjectConfig 名称(来自 ConfigRegistry",
)
parser.add_argument(
"--task_file",
type=str,
default=None,
help="可选:指向 TOML 任务配置(若提供则跳过 CLI 组装步骤)",
)
args = parser.parse_args()
return PointCLIArgs(
data_root=args.data_root,
test_file=args.test_file,
model_id=args.model_id,
point_configs=args.point_configs,
per_component=args.per_component,
num_vis=args.num_vis,
skip_inference=args.skip_inference,
skip_evaluation=args.skip_evaluation,
skip_visualization=args.skip_visualization,
skip_comparison=args.skip_comparison,
comparison_dir=args.comparison_dir,
config_name=args.config_name,
task_file=args.task_file,
)
def default_output_dir(num_points: int, per_component: bool) -> str:
# Set the output directory
if per_component:
return f"./results/point_prompt_{num_points}pts_per_comp_hf"
return f"./results/point_prompt_{num_points}pts_hf"
output_dir = f"./results/point_prompt_{num_points}pts_per_comp"
else:
output_dir = f"./results/point_prompt_{num_points}pts"
print("\n" + "=" * 80)
print(f"实验配置: {num_points} 个点 ({'每连通域' if per_component else '全局骨架'})")
print("=" * 80)
def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
steps: List[TaskStepConfig] = []
common = {
"data_root": args.data_root,
"test_file": args.test_file,
"model_id": args.model_id,
"output_dir": output_dir,
}
if not args.skip_inference:
steps.append(
TaskStepConfig(
kind="point_inference",
params={
**common,
"num_points": num_points,
"per_component": args.per_component,
},
)
)
if not args.skip_evaluation:
steps.append(
TaskStepConfig(
kind="legacy_evaluation",
params={
**common,
"pred_dir": f"{output_dir}/predictions",
"compute_skeleton": True,
},
)
)
if not args.skip_visualization:
steps.append(
TaskStepConfig(
kind="legacy_visualization",
params={
**common,
"pred_dir": f"{output_dir}/predictions",
"results_csv": f"{output_dir}/evaluation_results.csv",
"num_samples": args.num_vis,
"save_all": False,
"create_metrics_plot": True,
},
)
)
return TaskConfig(
name=f"point_cli_{num_points}",
description=f"Legacy point prompt pipeline ({num_points} pts)",
project_config_name=args.config_name,
steps=steps,
# Load the model
print("\nLoading SAM2 model...")
import torch
sam2_model = build_sam2(model_cfg, checkpoint)
predictor = SAM2ImagePredictor(sam2_model)
print("Model loaded!")
# Inference
print(f"\nStep 1/3: Inference ({num_points} points)")
start_time = time.time()
results = process_test_set(
data_root=data_root,
test_file=test_file,
predictor=predictor,
output_dir=output_dir,
num_points=num_points,
per_component=per_component
)
print(f"Inference finished! Elapsed: {time.time() - start_time:.2f}s")
# Evaluation
print("\nStep 2/3: Evaluation")
pred_dir = os.path.join(output_dir, "predictions")
start_time = time.time()
df_results = evaluate_test_set(
data_root=data_root,
test_file=test_file,
pred_dir=pred_dir,
output_dir=output_dir,
compute_skeleton=True
)
print(f"Evaluation finished! Elapsed: {time.time() - start_time:.2f}s")
# Visualization
print("\nStep 3/3: Visualization")
results_csv = os.path.join(output_dir, "evaluation_results.csv")
start_time = time.time()
visualize_test_set(
data_root=data_root,
test_file=test_file,
pred_dir=pred_dir,
output_dir=output_dir,
results_csv=results_csv,
num_samples=10,
save_all=False
)
create_metrics_distribution_plot(results_csv, output_dir)
print(f"Visualization finished! Elapsed: {time.time() - start_time:.2f}s")
return df_results
def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
csv_path = Path(output_dir) / "evaluation_results.csv"
if not csv_path.exists():
return None
return pd.read_csv(csv_path)
def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
if not results:
return
os.makedirs(output_dir, exist_ok=True)
summary_rows = []
for num_points, df in results.items():
summary_rows.append(
{
"num_points": num_points,
"iou_mean": df["iou"].mean(),
"iou_std": df["iou"].std(),
"dice_mean": df["dice"].mean(),
"dice_std": df["dice"].std(),
"f1_mean": df["f1_score"].mean(),
"f1_std": df["f1_score"].std(),
"precision_mean": df["precision"].mean(),
"recall_mean": df["recall"].mean(),
}
)
df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
summary_path.parent.mkdir(parents=True, exist_ok=True)
df_summary.to_csv(summary_path, index=False)
def compare_results(results_dict: dict, output_dir: str = "./results"):
"""对比不同点数的结果"""
import pandas as pd
import matplotlib.pyplot as plt
metrics_to_plot = [
("iou_mean", "iou_std", "IoU"),
("dice_mean", "dice_std", "Dice"),
("f1_mean", "f1_std", "F1-Score"),
]
print("\n" + "=" * 80)
print("对比不同点数的性能")
print("=" * 80)
# 收集所有结果
summary = []
for num_points, df in results_dict.items():
metrics = {
'num_points': num_points,
'iou_mean': df['iou'].mean(),
'iou_std': df['iou'].std(),
'dice_mean': df['dice'].mean(),
'dice_std': df['dice'].std(),
'f1_mean': df['f1_score'].mean(),
'f1_std': df['f1_score'].std(),
'precision_mean': df['precision'].mean(),
'recall_mean': df['recall'].mean(),
}
summary.append(metrics)
df_summary = pd.DataFrame(summary)
# Print the comparison table
print("\nPerformance comparison:")
print(df_summary.to_string(index=False))
# Save the comparison results
comparison_dir = os.path.join(output_dir, "point_comparison")
os.makedirs(comparison_dir, exist_ok=True)
csv_path = os.path.join(comparison_dir, "comparison_summary.csv")
df_summary.to_csv(csv_path, index=False)
print(f"\nComparison results saved to: {csv_path}")
# Plot the comparison figure
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
xs = df_summary["num_points"].tolist()
for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
ax.errorbar(
xs,
df_summary[mean_col],
yerr=df_summary[std_col],
marker="o",
capsize=5,
linewidth=2,
markersize=8,
)
ax.set_xlabel("Number of Points", fontsize=12)
metrics_to_plot = [
('iou_mean', 'iou_std', 'IoU'),
('dice_mean', 'dice_std', 'Dice'),
('f1_mean', 'f1_std', 'F1-Score')
]
for idx, (mean_col, std_col, title) in enumerate(metrics_to_plot):
ax = axes[idx]
x = df_summary['num_points']
y = df_summary[mean_col]
yerr = df_summary[std_col]
ax.errorbar(x, y, yerr=yerr, marker='o', capsize=5, linewidth=2, markersize=8)
ax.set_xlabel('Number of Points', fontsize=12)
ax.set_ylabel(title, fontsize=12)
ax.set_title(f"{title} vs Number of Points", fontsize=14)
ax.set_title(f'{title} vs Number of Points', fontsize=14)
ax.grid(True, alpha=0.3)
ax.set_xticks(xs)
ax.set_xticks(x)
plt.tight_layout()
plot_path = summary_path.with_name("performance_comparison.png")
fig.savefig(plot_path, dpi=150, bbox_inches="tight")
plot_path = os.path.join(comparison_dir, "performance_comparison.png")
fig.savefig(plot_path, dpi=150, bbox_inches='tight')
plt.close(fig)
print(f"对比图已保存到: {plot_path}")
def main() -> None:
logging.basicConfig(level=logging.INFO)
args = parse_args()
if args.task_file:
task = load_task_from_toml(args.task_file)
TaskRunner(task).run()
return
return df_summary
def main():
"""主函数"""
parser = argparse.ArgumentParser(
description="SAM2 点提示方式 - 多点数对比实验"
)
# 数据集参数
parser.add_argument(
"--data_root", type=str, default="./crack500",
help="数据集根目录"
)
parser.add_argument(
"--test_file", type=str, default="./crack500/test.txt",
help="测试集文件路径"
)
# 模型参数
parser.add_argument(
"--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
help="SAM2 模型检查点路径"
)
parser.add_argument(
"--model_cfg", type=str, default="sam2.1_hiera_s.yaml",
help="SAM2 模型配置文件"
)
# 实验参数
parser.add_argument(
"--point_configs", type=int, nargs='+', default=[1, 3, 5],
help="要测试的点数配置"
)
parser.add_argument(
"--per_component", action="store_true",
help="为每个连通域独立采样"
)
parser.add_argument(
"--skip_comparison", action="store_true",
help="跳过对比分析"
)
args = parser.parse_args()
print("=" * 80)
print("SAM2 点提示方式 - 多点数对比实验")
print("=" * 80)
print(f"数据集根目录: {args.data_root}")
print(f"测试集文件: {args.test_file}")
print(f"模型检查点: {args.checkpoint}")
print(f"点数配置: {args.point_configs}")
print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
print("=" * 80)
# Check CUDA
import torch
if not torch.cuda.is_available():
print("Warning: CUDA is not available; falling back to CPU (this will be slow)")
else:
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
# Run all experiments
results_dict = {}
comparison_data: Dict[int, pd.DataFrame] = {}
for num_points in args.point_configs:
output_dir = default_output_dir(num_points, args.per_component)
task = build_task_for_points(args, num_points, output_dir)
if not task.steps:
try:
df_results = run_single_experiment(
data_root=args.data_root,
test_file=args.test_file,
checkpoint=args.checkpoint,
model_cfg=args.model_cfg,
num_points=num_points,
per_component=args.per_component
)
results_dict[num_points] = df_results
except Exception as e:
print(f"\n实验失败 ({num_points} 个点): {str(e)}")
import traceback
traceback.print_exc()
continue
TaskRunner(task).run()
if not args.skip_comparison and not args.skip_evaluation:
df = load_results_csv(output_dir)
if df is not None:
comparison_data[num_points] = df
if not args.skip_comparison and comparison_data:
compare_results(comparison_data, args.comparison_dir)
# Comparison analysis
if not args.skip_comparison and len(results_dict) > 1:
try:
compare_results(results_dict)
except Exception as e:
print(f"\nComparison analysis failed: {str(e)}")
import traceback
traceback.print_exc()
# Done
print("\n" + "=" * 80)
print("All experiments finished!")
print("=" * 80)
print("\nResult directories:")
for num_points in args.point_configs:
if args.per_component:
output_dir = f"./results/point_prompt_{num_points}pts_per_comp"
else:
output_dir = f"./results/point_prompt_{num_points}pts"
print(f" - {num_points} points: {output_dir}")
if not args.skip_comparison and len(results_dict) > 1:
print(" - comparison analysis: ./results/point_comparison")
print("=" * 80)
if __name__ == "__main__":

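run_single_experiment and compare_results together reproduce the whole study; a sketch of the loop that main() performs, with the default 1/3/5-point configurations and checkpoint paths:

# Sketch: run the 1/3/5-point experiments and build the comparison table.
results_dict = {}
for num_points in [1, 3, 5]:
    results_dict[num_points] = run_single_experiment(
        data_root="./crack500",
        test_file="./crack500/test.txt",
        checkpoint="../sam2/checkpoints/sam2.1_hiera_small.pt",
        model_cfg="sam2.1_hiera_s.yaml",
        num_points=num_points,
        per_component=False,
    )

# Writes comparison_summary.csv and performance_comparison.png under ./results.
df_summary = compare_results(results_dict)
print(df_summary)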
View File

@@ -1,26 +1,145 @@
"""
SAM2 crack segmentation with bounding-box prompts (using HuggingFace Transformers)
SAM2 crack segmentation with bounding-box prompts
Extracts bounding boxes from the GT masks and segments with SAM2
"""
import os
import cv2
import numpy as np
import torch
from pathlib import Path
from typing import Dict, List
from typing import List, Tuple, Dict
from tqdm import tqdm
import json
import cv2
from .dataset.utils import extract_bboxes_from_mask, load_image_and_mask
from .hf_sam2_predictor import HFSam2Predictor
from .model.inference import predict_with_bbox_prompt
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
def extract_bboxes_from_mask(mask: np.ndarray, expand_ratio: float = 0.0) -> List[np.ndarray]:
"""
Extract bounding boxes for all connected components from a binary mask
Args:
mask: binary mask (H, W), values 0 or 255
expand_ratio: bounding-box expansion ratio, e.g. 0.05 expands by 5%
Returns:
List of bounding boxes in format [x1, y1, x2, y2]
"""
# Ensure the mask is binary
binary_mask = (mask > 0).astype(np.uint8)
# Connected-component analysis
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
binary_mask, connectivity=8
)
bboxes = []
# Skip the background (label 0)
for i in range(1, num_labels):
x, y, w, h, area = stats[i]
# Filter out regions that are too small (likely noise)
if area < 10: # minimum area threshold
continue
# Compute the bounding box
x1, y1 = x, y
x2, y2 = x + w, y + h
# Expand the bounding box (if requested)
if expand_ratio > 0:
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
w_new = w * (1 + expand_ratio)
h_new = h * (1 + expand_ratio)
x1 = max(0, int(cx - w_new / 2))
y1 = max(0, int(cy - h_new / 2))
x2 = min(mask.shape[1], int(cx + w_new / 2))
y2 = min(mask.shape[0], int(cy + h_new / 2))
bboxes.append(np.array([x1, y1, x2, y2]))
return bboxes
def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
"""
Load an image and its mask
Args:
image_path: path to the image
mask_path: path to the mask
Returns:
image: RGB image (H, W, 3)
mask: binary mask (H, W)
"""
# Load the image
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"Failed to load image: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
# Load the mask
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise ValueError(f"Failed to load mask: {mask_path}")
return image, mask
def predict_with_bbox_prompt(
predictor: SAM2ImagePredictor,
image: np.ndarray,
bboxes: List[np.ndarray]
) -> np.ndarray:
"""
Run SAM2 prediction with bounding-box prompts
Args:
predictor: SAM2ImagePredictor instance
image: RGB image (H, W, 3)
bboxes: list of bounding boxes, each [x1, y1, x2, y2]
Returns:
combined_mask: merged predicted mask (H, W)
"""
# Set the image
predictor.set_image(image)
# If there are no bounding boxes, return an empty mask
if len(bboxes) == 0:
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
# Merge all predicted masks into one
combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
# Predict for each bounding box
for bbox in bboxes:
masks, scores, logits = predictor.predict(
point_coords=None,
point_labels=None,
box=bbox[None, :], # shape: (1, 4)
multimask_output=False,
)
# Take the first mask (since multimask_output=False)
mask = masks[0] # shape: (H, W)
# Merge into the combined mask
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
# Convert to 0-255
combined_mask = combined_mask * 255
return combined_mask
def process_test_set(
data_root: str,
test_file: str,
predictor: HFSam2Predictor,
predictor: SAM2ImagePredictor,
output_dir: str,
expand_ratio: float = 0.0
) -> List[Dict]:
@@ -30,7 +149,7 @@ def process_test_set(
Args:
data_root: dataset root directory
test_file: path to the test split file (test.txt)
predictor: HFSam2Predictor instance
predictor: SAM2ImagePredictor instance
output_dir: output directory
expand_ratio: bounding-box expansion ratio
@@ -77,7 +196,7 @@ def process_test_set(
bboxes = extract_bboxes_from_mask(mask_gt, expand_ratio=expand_ratio)
# Predict with SAM2
with torch.inference_mode():
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
mask_pred = predict_with_bbox_prompt(predictor, image, bboxes)
# Save the predicted mask
@@ -96,9 +215,6 @@ def process_test_set(
except Exception as e:
print(f"Failed to process {img_path}: {str(e)}")
# print stack trace
import traceback
traceback.print_exc()
continue
# Save the results metadata
@@ -118,20 +234,22 @@ def main():
# Configuration
DATA_ROOT = "./crack500"
TEST_FILE = "./crack500/test.txt"
OUTPUT_DIR = "./results/bbox_prompt_hf"
OUTPUT_DIR = "./results/bbox_prompt"
# HuggingFace SAM2 model
MODEL_ID = "facebook/sam2-hiera-small"
# SAM2 model configuration
CHECKPOINT = "./sam2/checkpoints/sam2.1_hiera_small.pt"
MODEL_CFG = "sam2.1_hiera_s.yaml"
# Bounding-box expansion ratio
EXPAND_RATIO = 0.05 # 5% expansion
print("=" * 60)
print("SAM2 bounding-box prompting (HuggingFace) - evaluation on the Crack500 dataset")
print("SAM2 bounding-box prompting - evaluation on the Crack500 dataset")
print("=" * 60)
print(f"Dataset root: {DATA_ROOT}")
print(f"Test split file: {TEST_FILE}")
print(f"Model: {MODEL_ID}")
print(f"Model checkpoint: {CHECKPOINT}")
print(f"Model config: {MODEL_CFG}")
print(f"Bounding-box expansion ratio: {EXPAND_RATIO * 100}%")
print(f"Output directory: {OUTPUT_DIR}")
print("=" * 60)
@@ -142,10 +260,10 @@ def main():
else:
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
# 构建 SAM2 predictor
# 构建 SAM2 模型
print("\n加载 SAM2 模型...")
from .hf_sam2_predictor import build_hf_sam2_predictor
predictor = build_hf_sam2_predictor(model_id=MODEL_ID)
sam2_model = build_sam2(MODEL_CFG, CHECKPOINT)
predictor = SAM2ImagePredictor(sam2_model)
print("模型加载完成!")
# 处理测试集
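extract_bboxes_from_mask above is self-contained, so its behavior is easy to verify on a synthetic mask; a small worked example (two filled rectangles yield two [x1, y1, x2, y2] boxes):

# Worked example for extract_bboxes_from_mask on a two-component mask.
import numpy as np

mask = np.zeros((100, 100), dtype=np.uint8)
mask[10:30, 20:50] = 255  # component 1: rows 10-29, cols 20-49
mask[60:90, 5:15] = 255   # component 2: rows 60-89, cols 5-14

boxes = extract_bboxes_from_mask(mask, expand_ratio=0.0)
print(boxes)  # [array([20, 10, 50, 30]), array([ 5, 60, 15, 90])]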

View File

@@ -1,16 +0,0 @@
from .base import BaseDataset, DatasetRecord, ModelReadySample, collate_samples
from .registry import DatasetRegistry
from .utils import extract_bboxes_from_mask, load_image_and_mask
# ensure built-in datasets register themselves
from . import crack500 # noqa: F401
__all__ = [
"BaseDataset",
"DatasetRecord",
"ModelReadySample",
"collate_samples",
"DatasetRegistry",
"extract_bboxes_from_mask",
"load_image_and_mask",
]

View File

@@ -1,167 +0,0 @@
from __future__ import annotations
import abc
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from ..model_configuration.config import DatasetConfig
@dataclass
class DatasetRecord:
"""
Lightweight description of a single sample on disk.
"""
image_path: Path
mask_path: Optional[Path] = None
prompt_path: Optional[Path] = None
metadata: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ModelReadySample:
"""
Standard container that mirrors what Hugging Face pipelines expect.
"""
pixel_values: torch.Tensor | np.ndarray
prompts: Dict[str, Any] = field(default_factory=dict)
labels: Dict[str, Any] = field(default_factory=dict)
metadata: Dict[str, Any] = field(default_factory=dict)
def to_hf_dict(self) -> Dict[str, Any]:
payload = {
"pixel_values": self.pixel_values,
"metadata": self.metadata,
}
if self.prompts:
payload["prompts"] = self.prompts
if self.labels:
payload["labels"] = self.labels
return payload
class BaseDataset(Dataset):
"""
Common dataset base class that handles record bookkeeping, IO, and
formatting tensors for Hugging Face pipelines.
"""
dataset_name: str = "base"
def __init__(
self,
config: DatasetConfig,
transforms: Optional[Callable[[ModelReadySample], ModelReadySample]] = None,
return_hf_dict: bool = True,
) -> None:
self.config = config
self.transforms = transforms
self.return_hf_dict = return_hf_dict
self.records: List[DatasetRecord] = self.load_records()
def __len__(self) -> int:
return len(self.records)
def __getitem__(self, index: int) -> Dict[str, Any] | ModelReadySample:
record = self.records[index]
sample = self.prepare_sample(record)
if self.transforms:
sample = self.transforms(sample)
return sample.to_hf_dict() if self.return_hf_dict else sample
@abc.abstractmethod
def load_records(self) -> List[DatasetRecord]:
"""
Scan the dataset directory / annotation files and return
structured references to each item on disk.
"""
def prepare_sample(self, record: DatasetRecord) -> ModelReadySample:
"""
Load image/mask/prompt data from disk and wrap it inside ModelReadySample.
Subclasses can override this to implement custom augmentations or prompt generation.
"""
image = self._load_image(record.image_path)
mask = (
self._load_mask(record.mask_path)
if record.mask_path is not None
else None
)
prompts = self.build_prompts(record, mask)
labels = {"mask": mask} if mask is not None else {}
sample = ModelReadySample(
pixel_values=image,
prompts=prompts,
labels=labels,
metadata=record.metadata,
)
return sample
def build_prompts(
self, record: DatasetRecord, mask: Optional[np.ndarray]
) -> Dict[str, Any]:
"""
Derive prompts from metadata or masks.
Default implementation extracts bounding boxes from masks.
"""
if mask is None:
return {}
boxes = self._mask_to_bboxes(mask)
return {"boxes": boxes}
def _load_image(self, path: Path) -> np.ndarray:
image = Image.open(path).convert("RGB")
return np.array(image)
def _load_mask(self, path: Optional[Path]) -> Optional[np.ndarray]:
if path is None:
return None
mask = Image.open(path).convert("L")
return np.array(mask)
def _mask_to_bboxes(self, mask: np.ndarray) -> List[List[int]]:
"""
Helper that mirrors the legacy bbox extraction pipeline.
"""
if mask.ndim != 2:
raise ValueError("Mask must be 2-dimensional.")
ys, xs = np.where(mask > 0)
if ys.size == 0:
return []
x_min, x_max = xs.min(), xs.max()
y_min, y_max = ys.min(), ys.max()
return [[int(x_min), int(y_min), int(x_max), int(y_max)]]
def collate_samples(batch: Iterable[Dict[str, Any] | ModelReadySample]) -> Dict[str, Any]:
"""
Default collate_fn that merges ModelReadySample/HF dict outputs.
"""
pixel_values = []
prompts: List[Dict[str, Any]] = []
labels: List[Dict[str, Any]] = []
metadata: List[Dict[str, Any]] = []
for item in batch:
if isinstance(item, ModelReadySample):
payload = item.to_hf_dict()
else:
payload = item
pixel_values.append(payload["pixel_values"])
prompts.append(payload.get("prompts", {}))
labels.append(payload.get("labels", {}))
metadata.append(payload.get("metadata", {}))
stacked = {
"pixel_values": torch.as_tensor(np.stack(pixel_values)),
"prompts": prompts,
"labels": labels,
"metadata": metadata,
}
return stacked
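For reference, the removed BaseDataset left only load_records abstract; a minimal sketch of a concrete dataset under that contract (the directory layout here is hypothetical, not from this repo):

# Sketch: a minimal BaseDataset subclass. Assumed layout: <data_root>/images/*.jpg
# with same-stem masks under <data_root>/masks/.
class ToyCrackDataset(BaseDataset):
    dataset_name = "toy_crack"

    def load_records(self) -> List[DatasetRecord]:
        base = Path(self.config.data_root)
        records: List[DatasetRecord] = []
        for image_path in sorted((base / "images").glob("*.jpg")):
            mask_path = base / "masks" / (image_path.stem + ".png")
            records.append(DatasetRecord(
                image_path=image_path,
                mask_path=mask_path if mask_path.exists() else None,
                metadata={"image_name": image_path.name},
            ))
        return records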

View File

@@ -1,99 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from .base import BaseDataset, DatasetRecord
from .registry import DatasetRegistry
from .utils import (
extract_bboxes_from_mask,
sample_points_on_skeleton,
sample_points_per_component,
)
from ..model_configuration.config import DatasetConfig
@DatasetRegistry.register("crack500")
class Crack500Dataset(BaseDataset):
"""
Reference implementation that loads Crack500 samples from an image list.
"""
def __init__(
self,
config: DatasetConfig,
expand_ratio: float = 0.05,
min_area: int = 10,
**kwargs,
) -> None:
extra = dict(config.extra_params or {})
expand_ratio = float(extra.get("expand_ratio", expand_ratio))
self.prompt_mode = extra.get("prompt_mode", "bbox")
self.num_points = int(extra.get("num_points", 5))
self.per_component = bool(extra.get("per_component", False))
self.expand_ratio = expand_ratio
self.min_area = min_area
super().__init__(config, **kwargs)
def load_records(self) -> List[DatasetRecord]:
base_dir = Path(self.config.data_root)
list_file = (
Path(self.config.annotation_file)
if self.config.annotation_file
else base_dir / (self.config.split_file or "test.txt")
)
if not list_file.exists():
raise FileNotFoundError(f"Missing Crack500 split file: {list_file}")
image_dir = base_dir / (self.config.image_folder or "testcrop")
mask_dir = base_dir / (self.config.mask_folder or "testdata")
records: List[DatasetRecord] = []
with list_file.open("r", encoding="utf-8") as handle:
for line in handle:
image_name = line.strip()
if not image_name:
continue
image_path = image_dir / image_name
mask_name = image_name.replace(".jpg", ".png")
mask_path = mask_dir / mask_name
metadata = {"split": self.config.split, "image_name": image_name}
records.append(
DatasetRecord(
image_path=image_path,
mask_path=mask_path if mask_path.exists() else None,
metadata=metadata,
)
)
if not records:
raise RuntimeError(
f"No records found in {image_dir} for split {self.config.split}"
)
return records
def build_prompts(
self,
record: DatasetRecord,
mask: Optional[np.ndarray],
) -> Dict[str, List[List[int]]]:
if mask is None:
return {}
if self.prompt_mode == "point":
points, point_labels = self._build_point_prompts(mask)
if points.size == 0:
return {}
prompts: Dict[str, List[List[int]]] = {"points": points.tolist()}
if point_labels.size > 0:
prompts["point_labels"] = point_labels.tolist()
return prompts
boxes = extract_bboxes_from_mask(
mask, expand_ratio=self.expand_ratio, min_area=self.min_area
)
return {"boxes": boxes}
def _build_point_prompts(self, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if self.per_component:
return sample_points_per_component(mask, self.num_points)
points = sample_points_on_skeleton(mask, self.num_points)
labels = np.ones(points.shape[0], dtype=np.int32)
return points, labels

View File

@@ -1,33 +0,0 @@
from __future__ import annotations
from typing import Dict, Type
from .base import BaseDataset
class DatasetRegistry:
"""
Simple registry so configs can refer to datasets by string key.
"""
_registry: Dict[str, Type[BaseDataset]] = {}
@classmethod
def register(cls, name: str):
def decorator(dataset_cls: Type[BaseDataset]) -> Type[BaseDataset]:
cls._registry[name] = dataset_cls
dataset_cls.dataset_name = name
return dataset_cls
return decorator
@classmethod
def create(cls, name: str, *args, **kwargs) -> BaseDataset:
if name not in cls._registry:
raise KeyError(f"Dataset '{name}' is not registered.")
dataset_cls = cls._registry[name]
return dataset_cls(*args, **kwargs)
@classmethod
def available(cls) -> Dict[str, Type[BaseDataset]]:
return dict(cls._registry)
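Configs referred to datasets by string key through this registry; a sketch of both sides of the contract (toy_crack and dataset_cfg are illustrative placeholders):

# Sketch: register a dataset class, then instantiate it by key.
@DatasetRegistry.register("toy_crack")
class ToyCrackDataset(BaseDataset):
    def load_records(self):
        return []  # real implementations scan the dataset directory

dataset = DatasetRegistry.create("toy_crack", config=dataset_cfg)  # dataset_cfg: a DatasetConfig
print(sorted(DatasetRegistry.available()))  # e.g. ['crack500', 'toy_crack']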

View File

@@ -1,91 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import List, Tuple
import cv2
import numpy as np
from skimage.morphology import skeletonize
def load_image_and_mask(image_path: str | Path, mask_path: str | Path) -> Tuple[np.ndarray, np.ndarray]:
"""
Reads an RGB image and its mask counterpart.
"""
image_path = str(image_path)
mask_path = str(mask_path)
image = cv2.imread(image_path)
if image is None:
raise ValueError(f"无法加载图像: {image_path}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if mask is None:
raise ValueError(f"无法加载掩码: {mask_path}")
return image, mask
def extract_bboxes_from_mask(
mask: np.ndarray,
expand_ratio: float = 0.0,
min_area: int = 10,
) -> List[List[int]]:
"""
Extract bounding boxes from a binary mask using connected components.
"""
binary_mask = (mask > 0).astype(np.uint8)
num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
bboxes: List[List[int]] = []
for i in range(1, num_labels):
x, y, w, h, area = stats[i]
if area < min_area:
continue
x1, y1 = x, y
x2, y2 = x + w, y + h
if expand_ratio > 0:
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
w_new = w * (1 + expand_ratio)
h_new = h * (1 + expand_ratio)
x1 = max(0, int(cx - w_new / 2))
y1 = max(0, int(cy - h_new / 2))
x2 = min(mask.shape[1], int(cx + w_new / 2))
y2 = min(mask.shape[0], int(cy + h_new / 2))
bboxes.append([x1, y1, x2, y2])
return bboxes
def sample_points_on_skeleton(mask: np.ndarray, num_points: int) -> np.ndarray:
"""
Sample points uniformly along the mask skeleton in (x, y) order.
"""
binary_mask = (mask > 0).astype(bool)
try:
skeleton = skeletonize(binary_mask)
except Exception:
skeleton = binary_mask
coords = np.argwhere(skeleton)
if coords.size == 0:
return np.zeros((0, 2), dtype=np.int32)
if coords.shape[0] <= num_points:
points = coords[:, [1, 0]]
return points.astype(np.int32)
indices = np.linspace(0, coords.shape[0] - 1, num_points, dtype=int)
sampled = coords[indices][:, [1, 0]]
return sampled.astype(np.int32)
def sample_points_per_component(mask: np.ndarray, num_points_per_component: int) -> Tuple[np.ndarray, np.ndarray]:
"""
Sample points per connected component along each component's skeleton.
"""
num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
all_points = []
for region_id in range(1, num_labels):
region_mask = (labels_map == region_id).astype(np.uint8) * 255
points = sample_points_on_skeleton(region_mask, num_points_per_component)
if len(points):
all_points.append(points)
if not all_points:
return np.zeros((0, 2), dtype=np.int32), np.zeros(0, dtype=np.int32)
stacked = np.vstack(all_points)
labels = np.ones(stacked.shape[0], dtype=np.int32)
return stacked, labels
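sample_points_on_skeleton returns (x, y) points spread evenly along the skeleton, which for a thick bar is roughly its centerline; a runnable check on a synthetic mask (exact endpoints depend on skeletonize, hence approximate values):

# Worked example: points sampled along the skeleton of a horizontal bar.
import numpy as np

mask = np.zeros((50, 100), dtype=np.uint8)
mask[20:31, 10:90] = 255  # thick bar whose skeleton sits near row 25

points = sample_points_on_skeleton(mask, num_points=3)
print(points)  # approximately [[11 25], [50 25], [88 25]]: (x, y) pairs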

View File

@@ -1,14 +0,0 @@
from .metrics import METRIC_REGISTRY, compute_dice, compute_iou, compute_precision, compute_recall
from .pipeline_eval import PipelineEvaluator
from .reporting import write_csv, write_json
__all__ = [
"METRIC_REGISTRY",
"PipelineEvaluator",
"compute_dice",
"compute_iou",
"compute_precision",
"compute_recall",
"write_csv",
"write_json",
]

View File

@@ -1,57 +0,0 @@
from __future__ import annotations
from typing import Callable, Dict, Iterable, Tuple
import numpy as np
def compute_iou(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
pred_bin = (pred >= threshold).astype(np.uint8)
target_bin = (target > 0).astype(np.uint8)
intersection = (pred_bin & target_bin).sum()
union = (pred_bin | target_bin).sum()
return float(intersection / union) if union else 0.0
def compute_dice(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
pred_bin = (pred >= threshold).astype(np.uint8)
target_bin = (target > 0).astype(np.uint8)
intersection = (pred_bin & target_bin).sum()
total = pred_bin.sum() + target_bin.sum()
return float((2 * intersection) / total) if total else 0.0
def compute_precision(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
pred_bin = (pred >= threshold).astype(np.uint8)
target_bin = (target > 0).astype(np.uint8)
tp = (pred_bin & target_bin).sum()
fp = (pred_bin & (1 - target_bin)).sum()
return float(tp / (tp + fp)) if (tp + fp) else 0.0
def compute_recall(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
pred_bin = (pred >= threshold).astype(np.uint8)
target_bin = (target > 0).astype(np.uint8)
tp = (pred_bin & target_bin).sum()
fn = ((1 - pred_bin) & target_bin).sum()
return float(tp / (tp + fn)) if (tp + fn) else 0.0
MetricFn = Callable[[np.ndarray, np.ndarray, float], float]
METRIC_REGISTRY: Dict[str, MetricFn] = {
"iou": compute_iou,
"dice": compute_dice,
"precision": compute_precision,
"recall": compute_recall,
}
def resolve_metrics(metric_names: Iterable[str]) -> Dict[str, MetricFn]:
resolved: Dict[str, MetricFn] = {}
for name in metric_names:
if name not in METRIC_REGISTRY:
raise KeyError(f"Metric '{name}' is not registered.")
resolved[name] = METRIC_REGISTRY[name]
return resolved
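These metrics binarize both inputs and reduce to pixel counts, so small cases can be verified by hand: with one overlapping pixel out of two predicted and two target pixels, IoU = 1/3 and Dice = 2*1/(2+2) = 1/2:

# Worked example for the metric functions above.
import numpy as np

pred = np.array([[1.0, 1.0, 0.0, 0.0]])  # two predicted pixels
target = np.array([[0, 255, 255, 0]])    # two target pixels, one overlaps

print(compute_iou(pred, target))        # 0.3333...
print(compute_dice(pred, target))       # 0.5
print(compute_precision(pred, target))  # 0.5 (1 TP, 1 FP)
print(compute_recall(pred, target))     # 0.5 (1 TP, 1 FN)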

View File

@@ -1,95 +0,0 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
from tqdm import tqdm
from ..dataset import BaseDataset
from ..model import BaseModelAdapter
from ..model_configuration import EvaluationConfig
from .metrics import resolve_metrics
from .utils import extract_mask_from_pipeline_output
class PipelineEvaluator:
"""
Runs a Hugging Face pipeline across a dataset and aggregates metrics.
"""
def __init__(
self,
dataset: BaseDataset,
adapter: BaseModelAdapter,
config: EvaluationConfig,
) -> None:
self.dataset = dataset
self.adapter = adapter
self.config = config
self.metrics = resolve_metrics(config.metrics)
def run(self) -> Dict[str, Any]:
pipe = self.adapter.build_pipeline()
aggregated: Dict[str, List[float]] = {name: [] for name in self.metrics}
output_dir = Path(self.config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
requested = self.config.max_samples or len(self.dataset)
total = min(requested, len(self.dataset))
prog_bar = tqdm(range(total), total=total)
for idx in prog_bar:
sample = self.dataset[idx]
inputs = self._build_pipeline_inputs(sample)
preds = pipe(**inputs)
labels = sample.get("labels", {})
mask = labels.get("mask")
if mask is None:
continue
prediction_mask = self._extract_mask(preds)
for metric_name, metric_fn in self.metrics.items():
for threshold in self.config.thresholds:
value = metric_fn(prediction_mask, mask, threshold)
aggregated.setdefault(f"{metric_name}@{threshold}", []).append(value)
if self.config.save_predictions:
self._write_prediction(output_dir, idx, prediction_mask, sample["metadata"])
summary = {
"metrics": {k: float(np.mean(v)) if v else 0.0 for k, v in aggregated.items()},
"config": self.config.__dict__,
"num_samples": total,
}
with (output_dir / "evaluation_summary.json").open("w", encoding="utf-8") as handle:
json.dump(summary, handle, indent=2)
return summary
def _build_pipeline_inputs(self, sample: Dict[str, Any]) -> Dict[str, Any]:
inputs: Dict[str, Any] = {"images": sample["pixel_values"]}
prompts = sample.get("prompts", {})
if "boxes" in prompts and prompts["boxes"]:
inputs["boxes"] = prompts["boxes"]
if "points" in prompts and prompts["points"]:
inputs["points"] = prompts["points"]
if "point_labels" in prompts and prompts["point_labels"]:
inputs["point_labels"] = prompts["point_labels"]
return inputs
def _extract_mask(self, pipeline_output: Any) -> np.ndarray:
"""
Normalize pipeline outputs into numpy masks.
"""
return extract_mask_from_pipeline_output(pipeline_output)
def _write_prediction(
self,
output_dir: Path,
index: int,
mask: np.ndarray,
metadata: Optional[Dict[str, Any]],
) -> None:
if metadata and "image_name" in metadata:
filename = metadata["image_name"].replace(".jpg", "_pred.npy")
else:
filename = f"sample_{index:04d}_pred.npy"
target_path = output_dir / "predictions"
target_path.mkdir(parents=True, exist_ok=True)
np.save(target_path / filename, mask)

View File

@@ -1,25 +0,0 @@
from __future__ import annotations
import csv
import json
from pathlib import Path
from typing import Dict, Iterable
def write_json(summary: Dict, output_path: Path) -> None:
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("w", encoding="utf-8") as handle:
json.dump(summary, handle, indent=2)
def write_csv(rows: Iterable[Dict], output_path: Path) -> None:
rows = list(rows)
if not rows:
return
output_path.parent.mkdir(parents=True, exist_ok=True)
fieldnames = sorted(rows[0].keys())
with output_path.open("w", encoding="utf-8", newline="") as handle:
writer = csv.DictWriter(handle, fieldnames=fieldnames)
writer.writeheader()
for row in rows:
writer.writerow(row)

View File

@@ -1,55 +0,0 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, replace
from typing import Optional
from transformers import HfArgumentParser
from ..dataset import DatasetRegistry
from ..model import ModelRegistry
from ..model_configuration import ConfigRegistry, EvaluationConfig
from .pipeline_eval import PipelineEvaluator
LOGGER = logging.getLogger(__name__)
@dataclass
class PipelineCLIArguments:
config_name: str = "sam2_bbox_prompt"
model_key: str = "sam2"
split: str = "test"
split_file: Optional[str] = None
device: Optional[str] = None
max_samples: Optional[int] = None
def main() -> None:
parser = HfArgumentParser(PipelineCLIArguments)
(cli_args,) = parser.parse_args_into_dataclasses()
project_config = ConfigRegistry.get(cli_args.config_name)
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
dataset = DatasetRegistry.create(
dataset_cfg.name,
config=dataset_cfg,
return_hf_dict=True,
)
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
evaluation_config = replace(
project_config.evaluation,
max_samples=cli_args.max_samples,
)
if cli_args.device:
adapter.build_pipeline(device=cli_args.device)
evaluator = PipelineEvaluator(
dataset=dataset,
adapter=adapter,
config=evaluation_config,
)
summary = evaluator.run()
LOGGER.info("Evaluation summary: %s", summary)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

View File

@@ -1,16 +0,0 @@
from __future__ import annotations
from typing import Any
import numpy as np
def extract_mask_from_pipeline_output(pipeline_output: Any) -> np.ndarray:
if isinstance(pipeline_output, list):
pipeline_output = pipeline_output[0]
mask = pipeline_output.get("mask")
if mask is None:
raise ValueError("Pipeline output missing 'mask'.")
if isinstance(mask, np.ndarray):
return mask
return np.array(mask)

View File

@@ -1,7 +0,0 @@
"""
Backward-compatible wrapper that re-exports the predictor relocated to src.model.
"""
from .model.predictor import HFSam2Predictor, build_hf_sam2_predictor
__all__ = ["HFSam2Predictor", "build_hf_sam2_predictor"]

View File

@@ -1,17 +0,0 @@
from .base import BaseModelAdapter
from .inference import predict_with_bbox_prompt
from .predictor import HFSam2Predictor, build_hf_sam2_predictor
from .registry import ModelRegistry
from .sam2_adapter import Sam2ModelAdapter
from .trainer import FineTuningTrainer, TrainerArtifacts
__all__ = [
"BaseModelAdapter",
"FineTuningTrainer",
"HFSam2Predictor",
"ModelRegistry",
"Sam2ModelAdapter",
"TrainerArtifacts",
"build_hf_sam2_predictor",
"predict_with_bbox_prompt",
]

View File

@@ -1,66 +0,0 @@
from __future__ import annotations
import abc
from typing import Any, Dict, Optional
from transformers import pipeline
from ..model_configuration import ModelConfig
class BaseModelAdapter(abc.ABC):
"""
Thin wrapper that standardizes how we instantiate models/processors/pipelines.
"""
task: str = "image-segmentation"
def __init__(self, config: ModelConfig) -> None:
self.config = config
self._model = None
self._processor = None
self._pipeline = None
def load_pretrained(self):
if self._model is None or self._processor is None:
self._model, self._processor = self._load_pretrained()
return self._model, self._processor
def build_pipeline(
self,
device: Optional[str] = None,
**kwargs,
):
if self._pipeline is None:
model, processor = self.load_pretrained()
pipe_kwargs = {
"task": self.task,
"model": model,
"image_processor": processor,
**self.config.pipeline_kwargs,
**kwargs,
}
if device is not None:
pipe_kwargs["device"] = device
self._pipeline = self._create_pipeline(pipe_kwargs)
return self._pipeline
async def build_pipeline_async(self, **kwargs):
"""
Async helper for future multi-device orchestration.
"""
return self.build_pipeline(**kwargs)
def save_pretrained(self, output_dir: str) -> None:
model, processor = self.load_pretrained()
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
@abc.abstractmethod
def _load_pretrained(self):
"""
Return (model, processor) tuple.
"""
def _create_pipeline(self, pipe_kwargs: Dict[str, Any]):
return pipeline(**pipe_kwargs)

View File

@ -1,32 +0,0 @@
from __future__ import annotations
from typing import List
import numpy as np
from .predictor import HFSam2Predictor
def predict_with_bbox_prompt(
predictor: HFSam2Predictor,
image: np.ndarray,
bboxes: List[np.ndarray],
) -> np.ndarray:
"""
Run SAM2 predictions for each bounding box and merge the masks.
"""
predictor.set_image(image)
if not bboxes:
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
for bbox in bboxes:
masks, _, _ = predictor.predict(
point_coords=None,
point_labels=None,
box=bbox,
multimask_output=False,
)
mask = masks[0]
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
combined_mask = combined_mask * 255
return combined_mask
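Note: a usage sketch for the merge helper above with a dummy image and one [x1, y1, x2, y2] box (import paths are assumptions; the predictor pulls the default checkpoint plus the repo's local processor config):

import numpy as np

from src.model.inference import predict_with_bbox_prompt  # assumed paths
from src.model.predictor import build_hf_sam2_predictor

predictor = build_hf_sam2_predictor()            # downloads facebook/sam2-hiera-small
image = np.zeros((256, 256, 3), dtype=np.uint8)  # dummy RGB image
bboxes = [np.array([32, 32, 128, 128])]          # one box; an empty list yields an all-zero mask
mask = predict_with_bbox_prompt(predictor, image, bboxes)
print(mask.shape, mask.dtype)                    # (256, 256), uint8, values in {0, 255}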

View File

@@ -1,158 +0,0 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import torch
from PIL import Image
from transformers import SamModel, SamProcessor
class HFSam2Predictor:
"""
Predictor wrapper around Hugging Face SAM2 models.
"""
def __init__(
self,
model_id: str = "facebook/sam2-hiera-small",
device: Optional[str] = None,
dtype: torch.dtype = torch.bfloat16,
) -> None:
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self.dtype = dtype
self.model = SamModel.from_pretrained(model_id).to(self.device)
self.processor = SamProcessor.from_pretrained("./configs/preprocesser.json")
self._override_processor_config()
if dtype == torch.bfloat16:
self.model = self.model.to(dtype=dtype)
self.model.eval()
self.current_image = None
self.current_image_embeddings = None
def set_image(self, image: np.ndarray) -> None:
if isinstance(image, np.ndarray):
pil_image = Image.fromarray(image.astype(np.uint8))
else:
pil_image = image
self.current_image = pil_image
with torch.inference_mode():
inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device)
if self.dtype == torch.bfloat16:
inputs = {
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
for k, v in inputs.items()
}
self.current_image_embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
def predict(
self,
point_coords: Optional[np.ndarray] = None,
point_labels: Optional[np.ndarray] = None,
box: Optional[np.ndarray] = None,
multimask_output: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
if self.current_image is None:
raise ValueError("No image set. Call set_image() first.")
input_points = self._prepare_points(point_coords)
input_labels = self._prepare_labels(point_labels)
input_boxes = self._prepare_boxes(box)
with torch.inference_mode():
inputs = self.processor(
images=self.current_image,
input_points=input_points,
input_labels=input_labels,
input_boxes=input_boxes,
return_tensors="pt",
).to(self.device)
if self.dtype == torch.bfloat16:
inputs = {
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
for k, v in inputs.items()
}
inputs.pop("pixel_values", None)
inputs["image_embeddings"] = self.current_image_embeddings
outputs = self.model(**inputs, multimask_output=multimask_output)
masks = self.processor.image_processor.post_process_masks(
outputs.pred_masks.float().cpu(),
inputs["original_sizes"].cpu(),
inputs["reshaped_input_sizes"].cpu(),
)[0]
scores = outputs.iou_scores.float().cpu().numpy()[0]
masks_np = (masks.squeeze(1).numpy() > 0).astype(np.uint8)
logits = outputs.pred_masks.float().cpu().numpy()[0]
return masks_np, scores, logits
def _prepare_points(self, coords: Optional[np.ndarray]):
"""
Points must be shaped (num_points, 2); wrap in outer batch dimension.
"""
if coords is None:
return None
coords_arr = np.asarray(coords)
if coords_arr.ndim == 1:
coords_arr = coords_arr[None, :]
if coords_arr.ndim != 2:
raise ValueError(f"Point coords must be 2-D, got {coords_arr.shape}.")
return [coords_arr.tolist()]
def _prepare_labels(self, labels: Optional[np.ndarray]):
"""
Labels mirror the point dimension and are shaped (num_points,).
"""
if labels is None:
return None
labels_arr = np.asarray(labels)
if labels_arr.ndim == 0:
labels_arr = labels_arr[None]
if labels_arr.ndim != 1:
raise ValueError(f"Point labels must be 1-D, got {labels_arr.shape}.")
return [labels_arr.tolist()]
def _prepare_boxes(self, boxes: Optional[np.ndarray]):
"""
HF expects boxes in shape (batch, num_boxes, 4); accept (4,), (N,4), or (B,N,4).
"""
if boxes is None:
return None
boxes_arr = np.asarray(boxes)
if boxes_arr.ndim == 1:
return [[boxes_arr.tolist()]]
if boxes_arr.ndim == 2:
return [boxes_arr.tolist()]
if boxes_arr.ndim == 3:
return boxes_arr.tolist()
raise ValueError(f"Boxes should be 1/2/3-D, got {boxes_arr.shape}.")
def _override_processor_config(self) -> None:
"""
Override processor config with local settings to avoid upstream regressions.
"""
config_path = Path(__file__).resolve().parents[2] / "configs" / "preprocesser.json"
if not config_path.exists():
return
try:
config_dict = json.loads(config_path.read_text())
except Exception:
return
image_processor = getattr(self.processor, "image_processor", None)
if image_processor is None or not hasattr(image_processor, "config"):
return
# config behaves like a dict; update in-place.
try:
image_processor.config.update(config_dict)
except Exception:
for key, value in config_dict.items():
try:
setattr(image_processor.config, key, value)
except Exception:
continue
def build_hf_sam2_predictor(
model_id: str = "facebook/sam2-hiera-small",
device: Optional[str] = None,
) -> HFSam2Predictor:
return HFSam2Predictor(model_id=model_id, device=device)
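Note: the class above implements an embed-once/prompt-many flow; a sketch of driving it directly (import path assumed; expects the repo's configs/preprocesser.json to be present):

import numpy as np

from src.model.predictor import HFSam2Predictor  # assumed path

predictor = HFSam2Predictor(model_id="facebook/sam2-hiera-small")
image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy RGB image
predictor.set_image(image)                       # computes image embeddings once
# Prompt against the cached embeddings; box format is [x1, y1, x2, y2].
masks, scores, logits = predictor.predict(box=np.array([100, 100, 300, 300]))
print(masks.shape, scores.shape)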

View File

@@ -1,33 +0,0 @@
from __future__ import annotations
from typing import Dict, Type
from ..model_configuration import ModelConfig
from .base import BaseModelAdapter
class ModelRegistry:
"""
Maps model keys to adapter classes so configs can reference them declaratively.
"""
_registry: Dict[str, Type[BaseModelAdapter]] = {}
@classmethod
def register(cls, name: str):
def decorator(adapter_cls: Type[BaseModelAdapter]) -> Type[BaseModelAdapter]:
cls._registry[name] = adapter_cls
return adapter_cls
return decorator
@classmethod
def create(cls, name: str, config: ModelConfig) -> BaseModelAdapter:
if name not in cls._registry:
raise KeyError(f"ModelAdapter '{name}' is not registered.")
adapter_cls = cls._registry[name]
return adapter_cls(config)
@classmethod
def available(cls) -> Dict[str, Type[BaseModelAdapter]]:
return dict(cls._registry)
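Note: adapters self-register through the decorator; a sketch of wiring in a new one (import paths assumed; DummyAdapter is purely illustrative):

from src.model.base import BaseModelAdapter  # assumed paths
from src.model.registry import ModelRegistry
from src.model_configuration import ModelConfig

@ModelRegistry.register("dummy")
class DummyAdapter(BaseModelAdapter):
    def _load_pretrained(self):
        raise NotImplementedError("illustrative only")

adapter = ModelRegistry.create("dummy", ModelConfig(name_or_path="x"))
print(sorted(ModelRegistry.available()))  # ['dummy', 'sam2', ...]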

View File

@@ -1,35 +0,0 @@
from __future__ import annotations
from typing import Any, Tuple
from transformers import AutoModelForImageSegmentation, AutoProcessor
from ..model_configuration import ModelConfig
from .base import BaseModelAdapter
from .registry import ModelRegistry
@ModelRegistry.register("sam2")
class Sam2ModelAdapter(BaseModelAdapter):
"""
Adapter that exposes SAM2 checkpoints through the HF pipeline interface.
"""
def __init__(self, config: ModelConfig) -> None:
super().__init__(config)
self.task = "image-segmentation"
def _load_pretrained(self) -> Tuple[Any, Any]:
model = AutoModelForImageSegmentation.from_pretrained(
self.config.name_or_path,
revision=self.config.revision,
cache_dir=self.config.cache_dir,
trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
self.config.name_or_path,
revision=self.config.revision,
cache_dir=self.config.cache_dir,
trust_remote_code=True,
)
return model, processor

View File

@@ -1,88 +0,0 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, replace
from typing import Optional
from transformers import HfArgumentParser
from ..dataset import DatasetRegistry
from ..model_configuration import ConfigRegistry, DatasetConfig
from .registry import ModelRegistry
from .trainer import FineTuningTrainer
LOGGER = logging.getLogger(__name__)
@dataclass
class TrainCLIArguments:
config_name: str = "sam2_bbox_prompt"
model_key: str = "sam2"
train_split: str = "train"
eval_split: str = "val"
train_split_file: Optional[str] = None
eval_split_file: Optional[str] = None
skip_eval: bool = False
device: Optional[str] = None
def build_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
overrides = {}
if split:
overrides["split"] = split
if split_file:
overrides["split_file"] = split_file
return replace(config, **overrides)
def main() -> None:
parser = HfArgumentParser(TrainCLIArguments)
(cli_args,) = parser.parse_args_into_dataclasses()
project_config = ConfigRegistry.get(cli_args.config_name)
train_dataset_cfg = build_dataset(
project_config.dataset, cli_args.train_split, cli_args.train_split_file
)
eval_dataset_cfg = (
build_dataset(project_config.dataset, cli_args.eval_split, cli_args.eval_split_file)
if not cli_args.skip_eval
else None
)
train_dataset = DatasetRegistry.create(
train_dataset_cfg.name,
config=train_dataset_cfg,
return_hf_dict=True,
)
eval_dataset = (
DatasetRegistry.create(
eval_dataset_cfg.name,
config=eval_dataset_cfg,
return_hf_dict=True,
)
if eval_dataset_cfg
else None
)
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
if cli_args.device:
adapter.build_pipeline(device=cli_args.device)
trainer_builder = FineTuningTrainer(
adapter=adapter,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
training_config=project_config.training,
)
artifacts = trainer_builder.build()
LOGGER.info("Starting training with args: %s", artifacts.training_args)
train_result = artifacts.trainer.train()
LOGGER.info("Training finished: %s", train_result)
artifacts.trainer.save_model(project_config.training.output_dir)
if eval_dataset and not cli_args.skip_eval:
metrics = artifacts.trainer.evaluate()
LOGGER.info("Evaluation metrics: %s", metrics)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

View File

@@ -1,64 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Optional
from transformers import Trainer, TrainingArguments
from ..dataset import BaseDataset, collate_samples
from ..model_configuration import TrainingConfig
from .base import BaseModelAdapter
@dataclass
class TrainerArtifacts:
trainer: Trainer
training_args: TrainingArguments
class FineTuningTrainer:
"""
Helper that bridges TrainingConfig + datasets + adapters into HF Trainer.
"""
def __init__(
self,
adapter: BaseModelAdapter,
train_dataset: Optional[BaseDataset],
eval_dataset: Optional[BaseDataset],
training_config: TrainingConfig,
trainer_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
self.adapter = adapter
self.train_dataset = train_dataset
self.eval_dataset = eval_dataset
self.training_config = training_config
self.trainer_kwargs = trainer_kwargs or {}
def build(self) -> TrainerArtifacts:
model, processor = self.adapter.load_pretrained()
training_args = TrainingArguments(
output_dir=self.training_config.output_dir,
num_train_epochs=self.training_config.num_train_epochs,
per_device_train_batch_size=self.training_config.per_device_train_batch_size,
per_device_eval_batch_size=self.training_config.per_device_eval_batch_size,
learning_rate=self.training_config.learning_rate,
gradient_accumulation_steps=self.training_config.gradient_accumulation_steps,
lr_scheduler_type=self.training_config.lr_scheduler_type,
warmup_ratio=self.training_config.warmup_ratio,
weight_decay=self.training_config.weight_decay,
seed=self.training_config.seed,
fp16=self.training_config.fp16,
bf16=self.training_config.bf16,
report_to=self.training_config.report_to,
)
hf_trainer = Trainer(
model=model,
args=training_args,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
data_collator=collate_samples,
tokenizer=processor,
**self.trainer_kwargs,
)
return TrainerArtifacts(trainer=hf_trainer, training_args=training_args)
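Note: a condensed version of what the training CLI above does with this builder, as a sketch (import paths assumed):

from src.dataset import DatasetRegistry  # assumed paths
from src.model import FineTuningTrainer, ModelRegistry
from src.model_configuration import ConfigRegistry

cfg = ConfigRegistry.get("sam2_bbox_prompt")
train_ds = DatasetRegistry.create(cfg.dataset.name, config=cfg.dataset, return_hf_dict=True)
adapter = ModelRegistry.create("sam2", cfg.model)
builder = FineTuningTrainer(adapter=adapter, train_dataset=train_ds,
                            eval_dataset=None, training_config=cfg.training)
artifacts = builder.build()
artifacts.trainer.train()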

View File

@@ -1,22 +0,0 @@
from .config import (
DatasetConfig,
EvaluationConfig,
ModelConfig,
ProjectConfig,
TrainingConfig,
VisualizationConfig,
)
from .registry import ConfigRegistry
# ensure example configs register themselves
from . import sam2_bbox # noqa: F401
__all__ = [
"DatasetConfig",
"EvaluationConfig",
"ModelConfig",
"ProjectConfig",
"TrainingConfig",
"VisualizationConfig",
"ConfigRegistry",
]

View File

@@ -1,89 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional
def _default_dict() -> Dict[str, Any]:
return {}
@dataclass
class DatasetConfig:
name: str
data_root: str
split: str = "test"
split_file: Optional[str] = None
annotation_file: Optional[str] = None
image_folder: Optional[str] = None
mask_folder: Optional[str] = None
extra_params: Dict[str, Any] = field(default_factory=_default_dict)
def resolve_path(self, relative: Optional[str]) -> Optional[Path]:
if relative is None:
return None
return Path(self.data_root) / relative
@dataclass
class ModelConfig:
name_or_path: str
revision: Optional[str] = None
config_name: Optional[str] = None
cache_dir: Optional[str] = None
prompt_type: str = "bbox"
image_size: Optional[int] = None
pipeline_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
adapter_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
@dataclass
class TrainingConfig:
output_dir: str = "./outputs"
num_train_epochs: float = 3.0
per_device_train_batch_size: int = 1
per_device_eval_batch_size: int = 1
learning_rate: float = 1e-4
weight_decay: float = 0.0
gradient_accumulation_steps: int = 1
lr_scheduler_type: str = "linear"
warmup_ratio: float = 0.0
seed: int = 42
fp16: bool = False
bf16: bool = False
report_to: List[str] = field(default_factory=lambda: ["tensorboard"])
@dataclass
class EvaluationConfig:
output_dir: str = "./results"
metrics: List[str] = field(default_factory=lambda: ["iou", "dice", "precision", "recall"])
thresholds: List[float] = field(default_factory=lambda: [0.5])
max_samples: Optional[int] = None
save_predictions: bool = True
@dataclass
class VisualizationConfig:
num_samples: int = 20
overlay_alpha: float = 0.6
save_dir: str = "./results/visualizations"
@dataclass
class ProjectConfig:
dataset: DatasetConfig
model: ModelConfig
training: TrainingConfig = field(default_factory=TrainingConfig)
evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
visualization: VisualizationConfig = field(default_factory=VisualizationConfig)
def to_dict(self) -> Dict[str, Any]:
return {
"dataset": self.dataset,
"model": self.model,
"training": self.training,
"evaluation": self.evaluation,
"visualization": self.visualization,
}

View File

@@ -1,28 +0,0 @@
from __future__ import annotations
from typing import Dict
from .config import ProjectConfig
class ConfigRegistry:
"""
Stores reusable project configurations (dataset + model + training bundle).
"""
_registry: Dict[str, ProjectConfig] = {}
@classmethod
def register(cls, name: str, config: ProjectConfig) -> ProjectConfig:
cls._registry[name] = config
return config
@classmethod
def get(cls, name: str) -> ProjectConfig:
if name not in cls._registry:
raise KeyError(f"ProjectConfig '{name}' is not registered.")
return cls._registry[name]
@classmethod
def available(cls) -> Dict[str, ProjectConfig]:
return dict(cls._registry)

View File

@@ -1,47 +0,0 @@
from __future__ import annotations
from .config import (
DatasetConfig,
EvaluationConfig,
ModelConfig,
ProjectConfig,
TrainingConfig,
VisualizationConfig,
)
from .registry import ConfigRegistry
SAM2_BBOX_CONFIG = ProjectConfig(
dataset=DatasetConfig(
name="crack500",
data_root="./crack500",
split="test",
split_file="test.txt",
image_folder="testcrop",
mask_folder="testdata",
),
model=ModelConfig(
name_or_path="facebook/sam2.1-hiera-small",
prompt_type="bbox",
pipeline_kwargs={"batch_size": 1},
),
training=TrainingConfig(
output_dir="./outputs/sam2_bbox",
num_train_epochs=5,
per_device_train_batch_size=1,
per_device_eval_batch_size=1,
learning_rate=1e-4,
gradient_accumulation_steps=4,
lr_scheduler_type="cosine",
),
evaluation=EvaluationConfig(
output_dir="./results/bbox_prompt",
thresholds=[0.3, 0.5, 0.75],
),
visualization=VisualizationConfig(
save_dir="./results/bbox_prompt/visualizations",
num_samples=20,
),
)
ConfigRegistry.register("sam2_bbox_prompt", SAM2_BBOX_CONFIG)
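Note: the CLIs in this diff derive variants of a registered bundle with dataclasses.replace instead of mutating it; the same pattern can register a derived config (a sketch, import path assumed):

from dataclasses import replace

from src.model_configuration import ConfigRegistry  # assumed path

base = ConfigRegistry.get("sam2_bbox_prompt")
val_variant = replace(base, dataset=replace(base.dataset, split="val", split_file="val.txt"))
ConfigRegistry.register("sam2_bbox_prompt_val", val_variant)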

View File

@@ -1,5 +1,5 @@
"""
Point-prompt SAM2 crack segmentation implementation (using HuggingFace Transformers)
Point-prompt SAM2 crack segmentation implementation
Uses a skeleton sampling strategy; supports 1, 3, or 5 points
"""
@@ -13,7 +13,8 @@ from tqdm import tqdm
import json
from skimage.morphology import skeletonize
from .hf_sam2_predictor import HFSam2Predictor
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
def sample_points_on_skeleton(mask: np.ndarray, num_points: int = 5) -> np.ndarray:
@@ -128,7 +129,7 @@ def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np
def predict_with_point_prompt(
predictor: HFSam2Predictor,
predictor: SAM2ImagePredictor,
image: np.ndarray,
points: np.ndarray,
point_labels: np.ndarray = None
@@ -137,7 +138,7 @@ def predict_with_point_prompt(
Run SAM2 prediction with point prompts
Args:
predictor: HFSam2Predictor instance
predictor: SAM2ImagePredictor instance
image: RGB image (H, W, 3)
points: point coordinates (N, 2), in [x, y] format
point_labels: point labels (N,); 1 = positive sample, 0 = negative sample
@@ -175,7 +176,7 @@ def predict_with_point_prompt(
def process_test_set(
data_root: str,
test_file: str,
predictor: HFSam2Predictor,
predictor: SAM2ImagePredictor,
output_dir: str,
num_points: int = 5,
per_component: bool = False
@@ -186,7 +187,7 @@ def process_test_set(
Args:
data_root: dataset root directory
test_file: test split file path (test.txt)
predictor: HFSam2Predictor instance
predictor: SAM2ImagePredictor instance
output_dir: output directory
num_points: number of sampled points
per_component: whether to sample points independently per connected component
@@ -241,7 +242,7 @@ def process_test_set(
point_labels = np.ones(len(points), dtype=np.int32)
# Run SAM2 prediction
with torch.inference_mode():
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
mask_pred = predict_with_point_prompt(
predictor, image, points, point_labels
)
@@ -280,22 +281,24 @@ def main():
"""Main function"""
import argparse
parser = argparse.ArgumentParser(description="SAM2 point-prompt mode (HuggingFace) - Crack500 dataset evaluation")
parser = argparse.ArgumentParser(description="SAM2 point-prompt mode - Crack500 dataset evaluation")
parser.add_argument("--data_root", type=str, default="./crack500", help="dataset root directory")
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="test split file")
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace model ID")
parser.add_argument("--output_dir", type=str, default="./results/point_prompt_hf", help="output directory")
parser.add_argument("--checkpoint", type=str, default="./sam2/checkpoints/sam2.1_hiera_small.pt", help="model checkpoint")
parser.add_argument("--model_cfg", type=str, default="sam2.1_hiera_s.yaml", help="model config")
parser.add_argument("--output_dir", type=str, default="./results/point_prompt", help="output directory")
parser.add_argument("--num_points", type=int, default=5, choices=[1, 3, 5], help="number of sampled points")
parser.add_argument("--per_component", action="store_true", help="sample points independently per connected component")
args = parser.parse_args()
print("=" * 60)
print("SAM2 point-prompt mode (HuggingFace) - Crack500 dataset evaluation")
print("SAM2 point-prompt mode - Crack500 dataset evaluation")
print("=" * 60)
print(f"Dataset root: {args.data_root}")
print(f"Test split file: {args.test_file}")
print(f"Model: {args.model_id}")
print(f"Model checkpoint: {args.checkpoint}")
print(f"Model config: {args.model_cfg}")
print(f"Sampled points: {args.num_points}")
print(f"Sampling strategy: {'per connected component' if args.per_component else 'global skeleton'}")
print(f"Output directory: {args.output_dir}")
@@ -307,10 +310,10 @@ def main():
else:
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
# Build SAM2 predictor
# Build SAM2 model
print("\nLoading SAM2 model...")
from .hf_sam2_predictor import build_hf_sam2_predictor
predictor = build_hf_sam2_predictor(model_id=args.model_id)
sam2_model = build_sam2(args.model_cfg, args.checkpoint)
predictor = SAM2ImagePredictor(sam2_model)
print("Model loaded!")
# Process test set

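Note: the hunks above reference sample_points_on_skeleton but never show its body; a minimal sketch of the idea (skeletonize the mask, then spread points along the skeleton), which may differ from the file's actual logic:

import numpy as np
from skimage.morphology import skeletonize

def sample_points_on_skeleton_sketch(mask: np.ndarray, num_points: int = 5) -> np.ndarray:
    """Return up to num_points [x, y] coordinates spread along the mask skeleton."""
    skeleton = skeletonize(mask > 0)
    ys, xs = np.nonzero(skeleton)  # np.nonzero yields (row, col), i.e. (y, x)
    if len(xs) == 0:
        return np.empty((0, 2), dtype=np.int64)
    idx = np.linspace(0, len(xs) - 1, num=min(num_points, len(xs)), dtype=int)
    return np.stack([xs[idx], ys[idx]], axis=1)  # [x, y], matching the docstring above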
View File

@@ -1,8 +0,0 @@
from .config import TaskConfig, TaskStepConfig
from .pipeline import TaskRunner
from .registry import TaskRegistry
# ensure built-in tasks are registered
from . import examples # noqa: F401
__all__ = ["TaskConfig", "TaskRunner", "TaskRegistry", "TaskStepConfig"]

View File

@@ -1,40 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional
TaskStepKind = Literal[
"train",
"evaluate",
"visualize",
"bbox_inference",
"point_inference",
"legacy_evaluation",
"legacy_visualization",
]
@dataclass
class TaskStepConfig:
kind: TaskStepKind
dataset_split: Optional[str] = None
dataset_split_file: Optional[str] = None
limit: Optional[int] = None
eval_split: Optional[str] = None
eval_split_file: Optional[str] = None
params: Dict[str, Any] = field(default_factory=dict)
@dataclass
class TaskConfig:
name: str
description: str
project_config_name: str
model_key: str = "sam2"
steps: List[TaskStepConfig] = field(default_factory=list)
dataset_overrides: Dict[str, Any] = field(default_factory=dict)
model_overrides: Dict[str, Any] = field(default_factory=dict)
training_overrides: Dict[str, Any] = field(default_factory=dict)
evaluation_overrides: Dict[str, Any] = field(default_factory=dict)
visualization_overrides: Dict[str, Any] = field(default_factory=dict)

View File

@@ -1,34 +0,0 @@
from __future__ import annotations
from .config import TaskConfig, TaskStepConfig
from .registry import TaskRegistry
TaskRegistry.register(
TaskConfig(
name="sam2_crack500_eval",
description="Evaluate SAM2 bbox prompt checkpoints on Crack500 and render overlays.",
project_config_name="sam2_bbox_prompt",
steps=[
TaskStepConfig(kind="evaluate", dataset_split="test"),
TaskStepConfig(kind="visualize", dataset_split="test", limit=20),
],
)
)
TaskRegistry.register(
TaskConfig(
name="sam2_crack500_train_eval",
description="Fine-tune SAM2 on Crack500 train split, evaluate on val, then visualize results.",
project_config_name="sam2_bbox_prompt",
steps=[
TaskStepConfig(
kind="train",
dataset_split="train",
eval_split="val",
params={"num_train_epochs": 2},
),
TaskStepConfig(kind="evaluate", dataset_split="val", limit=32),
TaskStepConfig(kind="visualize", dataset_split="val", limit=16),
],
)
)

View File

@@ -1,40 +0,0 @@
from __future__ import annotations
import tomllib
from pathlib import Path
from typing import Any, Dict, List
from .config import TaskConfig, TaskStepConfig
def load_task_from_toml(path: str | Path) -> TaskConfig:
"""
Load a TaskConfig from a TOML file.
"""
data = tomllib.loads(Path(path).read_text(encoding="utf-8"))
task_data = data.get("task", {})
steps_data: List[Dict[str, Any]] = data.get("steps", [])
steps = [
TaskStepConfig(
kind=step["kind"],
dataset_split=step.get("dataset_split"),
dataset_split_file=step.get("dataset_split_file"),
limit=step.get("limit"),
eval_split=step.get("eval_split"),
eval_split_file=step.get("eval_split_file"),
params=step.get("params", {}),
)
for step in steps_data
]
return TaskConfig(
name=task_data["name"],
description=task_data.get("description", ""),
project_config_name=task_data["project_config_name"],
model_key=task_data.get("model_key", "sam2"),
steps=steps,
dataset_overrides=task_data.get("dataset_overrides", {}),
model_overrides=task_data.get("model_overrides", {}),
training_overrides=task_data.get("training_overrides", {}),
evaluation_overrides=task_data.get("evaluation_overrides", {}),
visualization_overrides=task_data.get("visualization_overrides", {}),
)
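Note: a round-trip sketch for the loader above: write a tiny task file, then parse it (import path assumed):

from pathlib import Path

from src.task.io import load_task_from_toml  # assumed path

Path("demo_task.toml").write_text(
    '[task]\nname = "demo"\nproject_config_name = "sam2_bbox_prompt"\n\n'
    '[[steps]]\nkind = "evaluate"\ndataset_split = "test"\nlimit = 8\n',
    encoding="utf-8",
)
task = load_task_from_toml("demo_task.toml")
print(task.name, [step.kind for step in task.steps])  # demo ['evaluate']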

View File

@@ -1,264 +0,0 @@
from __future__ import annotations
import logging
from dataclasses import fields, replace
from pathlib import Path
from typing import Any, Dict, Optional
from ..bbox_prompt import process_test_set as bbox_process_test_set
from ..dataset import DatasetRegistry
from ..evaluation import PipelineEvaluator
from ..evaluation.utils import extract_mask_from_pipeline_output
from ..hf_sam2_predictor import build_hf_sam2_predictor
from ..legacy_evaluation import evaluate_test_set as legacy_evaluate_test_set
from ..legacy_visualization import (
create_metrics_distribution_plot,
visualize_test_set as legacy_visualize_test_set,
)
from ..model import FineTuningTrainer, ModelRegistry
from ..model_configuration import ConfigRegistry, DatasetConfig, ProjectConfig
from ..point_prompt import process_test_set as point_process_test_set
from ..visualization import OverlayGenerator
from .config import TaskConfig, TaskStepConfig
LOGGER = logging.getLogger(__name__)
def _replace_dataclass(instance, updates: Dict[str, Any]):
if not updates:
return instance
valid_fields = {f.name for f in fields(type(instance))}
filtered = {k: v for k, v in updates.items() if k in valid_fields}
if not filtered:
return instance
return replace(instance, **filtered)
def _override_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
updates: Dict[str, Any] = {"split": split}
if split_file:
updates["split_file"] = split_file
return replace(config, **updates)
class TaskRunner:
"""
Sequentially executes a series of task steps (train/eval/visualize).
"""
def __init__(self, task_config: TaskConfig, project_config: Optional[ProjectConfig] = None) -> None:
self.task_config = task_config
base_project = project_config or ConfigRegistry.get(task_config.project_config_name)
if project_config is None:
base_project = self._apply_project_overrides(base_project)
self.project_config = base_project
self.adapter = ModelRegistry.create(task_config.model_key, self.project_config.model)
def run(self) -> None:
LOGGER.info("Starting task '%s'", self.task_config.name)
for idx, step in enumerate(self.task_config.steps, start=1):
LOGGER.info("Running step %d/%d: %s", idx, len(self.task_config.steps), step.kind)
if step.kind == "train":
self._run_train(step)
elif step.kind == "evaluate":
self._run_evaluate(step)
elif step.kind == "visualize":
self._run_visualize(step)
elif step.kind == "bbox_inference":
self._run_bbox_inference(step)
elif step.kind == "point_inference":
self._run_point_inference(step)
elif step.kind == "legacy_evaluation":
self._run_legacy_evaluation(step)
elif step.kind == "legacy_visualization":
self._run_legacy_visualization(step)
else:
raise ValueError(f"Unknown task step: {step.kind}")
def _build_dataset(self, split: str, split_file: Optional[str]):
dataset_cfg = _override_dataset(self.project_config.dataset, split, split_file)
return DatasetRegistry.create(
dataset_cfg.name,
config=dataset_cfg,
return_hf_dict=True,
)
def _apply_project_overrides(self, config: ProjectConfig) -> ProjectConfig:
dataset_cfg = config.dataset
if self.task_config.dataset_overrides:
dataset_cfg = self._apply_dataset_overrides(dataset_cfg, self.task_config.dataset_overrides)
evaluation_cfg = config.evaluation
if self.task_config.evaluation_overrides:
evaluation_cfg = self._apply_simple_overrides(evaluation_cfg, self.task_config.evaluation_overrides)
visualization_cfg = config.visualization
if self.task_config.visualization_overrides:
visualization_cfg = self._apply_simple_overrides(
visualization_cfg, self.task_config.visualization_overrides
)
model_cfg = config.model
if self.task_config.model_overrides:
model_cfg = self._apply_simple_overrides(model_cfg, self.task_config.model_overrides)
training_cfg = config.training
if self.task_config.training_overrides:
training_cfg = self._apply_simple_overrides(training_cfg, self.task_config.training_overrides)
return replace(
config,
dataset=dataset_cfg,
model=model_cfg,
training=training_cfg,
evaluation=evaluation_cfg,
visualization=visualization_cfg,
)
def _apply_dataset_overrides(self, dataset_cfg: DatasetConfig, overrides: Dict[str, Any]) -> DatasetConfig:
overrides = dict(overrides)
extra_updates = overrides.pop("extra_params", {})
merged_extra = dict(dataset_cfg.extra_params or {})
merged_extra.update(extra_updates)
return replace(dataset_cfg, **overrides, extra_params=merged_extra)
def _apply_simple_overrides(self, cfg, overrides: Dict[str, Any]):
overrides = dict(overrides)
return replace(cfg, **overrides)
def _default_data_root(self) -> str:
return self.project_config.dataset.data_root
def _default_test_file(self) -> str:
dataset_cfg = self.project_config.dataset
candidate = dataset_cfg.split_file or "test.txt"
candidate_path = Path(candidate)
if candidate_path.is_absolute():
return str(candidate_path)
return str(Path(dataset_cfg.data_root) / candidate)
def _default_output_dir(self) -> str:
return self.project_config.evaluation.output_dir
def _run_train(self, step: TaskStepConfig) -> None:
train_dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
eval_dataset = None
if step.eval_split:
eval_dataset = self._build_dataset(step.eval_split, step.eval_split_file)
trainer_builder = FineTuningTrainer(
adapter=self.adapter,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
training_config=_replace_dataclass(
self.project_config.training,
dict(step.params),
),
)
artifacts = trainer_builder.build()
train_result = artifacts.trainer.train()
LOGGER.info("Training result: %s", train_result)
artifacts.trainer.save_model(self.project_config.training.output_dir)
if eval_dataset:
metrics = artifacts.trainer.evaluate()
LOGGER.info("Evaluation metrics: %s", metrics)
def _run_evaluate(self, step: TaskStepConfig) -> None:
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
evaluation_cfg = _replace_dataclass(
self.project_config.evaluation,
{**dict(step.params), "max_samples": step.limit},
)
evaluator = PipelineEvaluator(
dataset=dataset,
adapter=self.adapter,
config=evaluation_cfg,
)
summary = evaluator.run()
LOGGER.info("Evaluation summary: %s", summary)
def _run_visualize(self, step: TaskStepConfig) -> None:
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
vis_config = _replace_dataclass(
self.project_config.visualization,
{**dict(step.params), "num_samples": step.limit or self.project_config.visualization.num_samples},
)
overlay = OverlayGenerator(vis_config)
pipe = self.adapter.build_pipeline()
limit = min(vis_config.num_samples, len(dataset))
for idx in range(limit):
sample = dataset[idx]
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
pred_mask = extract_mask_from_pipeline_output(preds)
mask = sample.get("labels", {}).get("mask")
overlay.visualize_sample(
image=sample["pixel_values"],
prediction=pred_mask,
mask=mask,
metadata=sample.get("metadata"),
)
LOGGER.info("Saved overlays to %s", vis_config.save_dir)
def _run_bbox_inference(self, step: TaskStepConfig) -> None:
params = dict(step.params)
data_root = params.get("data_root", self._default_data_root())
test_file = params.get("test_file", self._default_test_file())
expand_ratio = params.get("expand_ratio", params.get("bbox_expand_ratio", 0.05))
output_dir = params.get("output_dir", self._default_output_dir())
model_id = params.get("model_id", self.project_config.model.name_or_path)
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
bbox_process_test_set(
data_root=data_root,
test_file=test_file,
predictor=predictor,
output_dir=output_dir,
expand_ratio=expand_ratio,
)
def _run_point_inference(self, step: TaskStepConfig) -> None:
params = dict(step.params)
data_root = params.get("data_root", self._default_data_root())
test_file = params.get("test_file", self._default_test_file())
num_points = params.get("num_points", 5)
per_component = params.get("per_component", False)
output_dir = params.get("output_dir") or f"./results/point_prompt_{num_points}pts_hf"
model_id = params.get("model_id", self.project_config.model.name_or_path)
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
point_process_test_set(
data_root=data_root,
test_file=test_file,
predictor=predictor,
output_dir=output_dir,
num_points=num_points,
per_component=per_component,
)
def _run_legacy_evaluation(self, step: TaskStepConfig) -> None:
params = dict(step.params)
data_root = params.get("data_root", self._default_data_root())
test_file = params.get("test_file", self._default_test_file())
output_dir = params.get("output_dir", self._default_output_dir())
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
compute_skeleton = params.get("compute_skeleton", True)
legacy_evaluate_test_set(
data_root=data_root,
test_file=test_file,
pred_dir=pred_dir,
output_dir=output_dir,
compute_skeleton=compute_skeleton,
)
def _run_legacy_visualization(self, step: TaskStepConfig) -> None:
params = dict(step.params)
data_root = params.get("data_root", self._default_data_root())
test_file = params.get("test_file", self._default_test_file())
output_dir = params.get("output_dir", self._default_output_dir())
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
num_samples = params.get("num_samples", 20)
save_all = params.get("save_all", False)
results_csv = params.get("results_csv", str(Path(output_dir) / "evaluation_results.csv"))
legacy_visualize_test_set(
data_root=data_root,
test_file=test_file,
pred_dir=pred_dir,
output_dir=output_dir,
results_csv=results_csv if Path(results_csv).exists() else None,
num_samples=num_samples,
save_all=save_all,
)
if params.get("create_metrics_plot", True):
create_metrics_distribution_plot(results_csv, output_dir)

View File

@@ -1,28 +0,0 @@
from __future__ import annotations
from typing import Dict
from .config import TaskConfig
class TaskRegistry:
"""
Holds named task configs for reuse.
"""
_registry: Dict[str, TaskConfig] = {}
@classmethod
def register(cls, task: TaskConfig) -> TaskConfig:
cls._registry[task.name] = task
return task
@classmethod
def get(cls, name: str) -> TaskConfig:
if name not in cls._registry:
raise KeyError(f"Task '{name}' is not registered.")
return cls._registry[name]
@classmethod
def available(cls) -> Dict[str, TaskConfig]:
return dict(cls._registry)

View File

@@ -1,44 +0,0 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Optional
from transformers import HfArgumentParser
from .config import TaskConfig
from .io import load_task_from_toml
from .pipeline import TaskRunner
from .registry import TaskRegistry
# ensure built-in tasks are registered when CLI runs
from . import examples # noqa: F401
LOGGER = logging.getLogger(__name__)
@dataclass
class TaskCLIArguments:
task_name: Optional[str] = None
task_file: Optional[str] = None
def resolve_task(cli_args: TaskCLIArguments) -> TaskConfig:
if not cli_args.task_name and not cli_args.task_file:
raise ValueError("Provide either --task_name or --task_file.")
if cli_args.task_file:
return load_task_from_toml(cli_args.task_file)
return TaskRegistry.get(cli_args.task_name)
def main() -> None:
parser = HfArgumentParser(TaskCLIArguments)
(cli_args,) = parser.parse_args_into_dataclasses()
task = resolve_task(cli_args)
runner = TaskRunner(task)
runner.run()
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

View File

@@ -81,17 +81,17 @@ def create_comparison_figure(
# Original image
axes[0, 0].imshow(image)
axes[0, 0].set_title("Original Image", fontsize=16)
axes[0, 0].set_title("Original Image", fontsize=12)
axes[0, 0].axis('off')
# GT mask
axes[0, 1].imshow(mask_gt, cmap='gray')
axes[0, 1].set_title("Ground Truth", fontsize=16)
axes[0, 1].set_title("Ground Truth", fontsize=12)
axes[0, 1].axis('off')
# Predicted mask
axes[1, 0].imshow(mask_pred, cmap='gray')
axes[1, 0].set_title("Prediction", fontsize=16)
axes[1, 0].set_title("Prediction", fontsize=12)
axes[1, 0].axis('off')
# Overlay visualization
@@ -110,16 +110,16 @@
axes[1, 1].text(
0.02, 0.98, legend_text,
transform=axes[1, 1].transAxes,
fontsize=16,
fontsize=10,
verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
)
axes[1, 1].set_title("Overlay Visualization", fontsize=16)
axes[1, 1].set_title("Overlay Visualization", fontsize=12)
axes[1, 1].axis('off')
# # Set overall title
# if title:
# fig.suptitle(title, fontsize=16, fontweight='bold')
# Set overall title
if title:
fig.suptitle(title, fontsize=14, fontweight='bold')
plt.tight_layout()

View File

@@ -1,4 +0,0 @@
from .gallery import build_gallery
from .overlay import OverlayGenerator
__all__ = ["OverlayGenerator", "build_gallery"]

View File

@@ -1,28 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import Iterable
from PIL import Image
def build_gallery(image_paths: Iterable[Path], output_path: Path, columns: int = 4) -> Path:
"""
Simple grid composer that stitches overlay PNGs into a gallery.
"""
image_paths = list(image_paths)
if not image_paths:
raise ValueError("No images provided for gallery.")
output_path.parent.mkdir(parents=True, exist_ok=True)
images = [Image.open(path).convert("RGB") for path in image_paths]
widths, heights = zip(*(img.size for img in images))
cell_w = max(widths)
cell_h = max(heights)
rows = (len(images) + columns - 1) // columns
canvas = Image.new("RGB", (cell_w * columns, cell_h * rows), color=(0, 0, 0))
for idx, img in enumerate(images):
row = idx // columns
col = idx % columns
canvas.paste(img, (col * cell_w, row * cell_h))
canvas.save(output_path)
return output_path
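Note: a usage sketch that stitches previously saved overlay PNGs into one grid (paths are hypothetical; import path assumed):

from pathlib import Path

from src.visualization.gallery import build_gallery  # assumed path

overlays = sorted(Path("./results/visualizations").glob("*_overlay.png"))
gallery_path = build_gallery(overlays, Path("./results/gallery.png"), columns=4)
print("gallery written to", gallery_path)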

View File

@@ -1,62 +0,0 @@
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Optional
import numpy as np
from PIL import Image
from ..model_configuration import VisualizationConfig
class OverlayGenerator:
"""
Turns model predictions into side-by-side overlays for quick inspection.
"""
def __init__(self, config: VisualizationConfig) -> None:
self.config = config
Path(self.config.save_dir).mkdir(parents=True, exist_ok=True)
def visualize_sample(
self,
image: np.ndarray,
prediction: np.ndarray,
mask: Optional[np.ndarray],
metadata: Optional[Dict[str, Any]] = None,
) -> Path:
overlay = self._compose_overlay(image, prediction, mask)
filename = (
metadata.get("image_name", "sample")
if metadata
else "sample"
)
target = Path(self.config.save_dir) / f"{filename}_overlay.png"
Image.fromarray(overlay).save(target)
return target
def _compose_overlay(
self,
image: np.ndarray,
prediction: np.ndarray,
mask: Optional[np.ndarray],
) -> np.ndarray:
vis = image.copy()
pred_mask = self._normalize(prediction)
color = np.zeros_like(vis)
color[..., 0] = pred_mask
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
if mask is not None:
gt = self._normalize(mask)
color = np.zeros_like(vis)
color[..., 1] = gt
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
return vis
def _normalize(self, array: np.ndarray) -> np.ndarray:
normalized = array.astype(np.float32)
normalized -= normalized.min()
denom = normalized.max() or 1.0
normalized = normalized / denom
normalized = (normalized * 255).astype(np.uint8)
return normalized
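Note: a self-contained sketch exercising the generator above on dummy arrays (import paths assumed):

import numpy as np

from src.model_configuration import VisualizationConfig  # assumed paths
from src.visualization.overlay import OverlayGenerator

gen = OverlayGenerator(VisualizationConfig(save_dir="./results/demo_overlays"))
image = np.zeros((64, 64, 3), dtype=np.uint8)
pred = np.zeros((64, 64), dtype=np.uint8)
pred[20:40, 20:40] = 1  # the prediction is min-max normalized before overlaying
out = gen.visualize_sample(image=image, prediction=pred, mask=None,
                           metadata={"image_name": "demo"})
print("overlay saved to", out)  # ./results/demo_overlays/demo_overlay.png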

View File

@@ -1,58 +0,0 @@
from __future__ import annotations
import logging
from dataclasses import dataclass, replace
from typing import Optional
from transformers import HfArgumentParser
from ..dataset import DatasetRegistry
from ..evaluation.utils import extract_mask_from_pipeline_output
from ..model import ModelRegistry
from ..model_configuration import ConfigRegistry
from .overlay import OverlayGenerator
LOGGER = logging.getLogger(__name__)
@dataclass
class VisualizationCLIArguments:
config_name: str = "sam2_bbox_prompt"
model_key: str = "sam2"
split: str = "test"
split_file: Optional[str] = None
num_samples: int = 20
device: Optional[str] = None
def main() -> None:
parser = HfArgumentParser(VisualizationCLIArguments)
(cli_args,) = parser.parse_args_into_dataclasses()
project_config = ConfigRegistry.get(cli_args.config_name)
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
dataset = DatasetRegistry.create(
dataset_cfg.name,
config=dataset_cfg,
return_hf_dict=True,
)
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
overlay = OverlayGenerator(project_config.visualization)
pipe = adapter.build_pipeline(device=cli_args.device)
limit = min(cli_args.num_samples, len(dataset))
for idx in range(limit):
sample = dataset[idx]
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
pred_mask = extract_mask_from_pipeline_output(preds)
mask = sample.get("labels", {}).get("mask")
overlay.visualize_sample(
image=sample["pixel_values"],
prediction=pred_mask,
mask=mask,
metadata=sample.get("metadata"),
)
LOGGER.info("Saved overlays to %s", project_config.visualization.save_dir)
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
main()

View File

@@ -1,34 +0,0 @@
[task]
name = "bbox_cli_template"
description = "Run legacy bbox-prompt inference + evaluation + visualization"
project_config_name = "sam2_bbox_prompt"
[[steps]]
kind = "bbox_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
output_dir = "./results/bbox_prompt"
expand_ratio = 0.05
[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/bbox_prompt"
pred_dir = "./results/bbox_prompt/predictions"
compute_skeleton = true
[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/bbox_prompt"
pred_dir = "./results/bbox_prompt/predictions"
results_csv = "./results/bbox_prompt/evaluation_results.csv"
num_samples = 20
save_all = false
create_metrics_plot = true

View File

@@ -1,100 +0,0 @@
[task]
name = "point_cli_template"
description = "Run legacy point-prompt inference/eval/visualization for multiple configs"
project_config_name = "sam2_bbox_prompt"
# 1-point config
[[steps]]
kind = "point_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
num_points = 1
per_component = false
output_dir = "./results/point_prompt_1pts_hf"
[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_1pts_hf"
pred_dir = "./results/point_prompt_1pts_hf/predictions"
compute_skeleton = true
[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_1pts_hf"
pred_dir = "./results/point_prompt_1pts_hf/predictions"
results_csv = "./results/point_prompt_1pts_hf/evaluation_results.csv"
num_samples = 10
save_all = false
create_metrics_plot = true
# 3-point config
[[steps]]
kind = "point_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
num_points = 3
per_component = false
output_dir = "./results/point_prompt_3pts_hf"
[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_3pts_hf"
pred_dir = "./results/point_prompt_3pts_hf/predictions"
compute_skeleton = true
[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_3pts_hf"
pred_dir = "./results/point_prompt_3pts_hf/predictions"
results_csv = "./results/point_prompt_3pts_hf/evaluation_results.csv"
num_samples = 10
save_all = false
create_metrics_plot = true
# 5-point config
[[steps]]
kind = "point_inference"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
model_id = "facebook/sam2-hiera-small"
num_points = 5
per_component = false
output_dir = "./results/point_prompt_5pts_hf"
[[steps]]
kind = "legacy_evaluation"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_5pts_hf"
pred_dir = "./results/point_prompt_5pts_hf/predictions"
compute_skeleton = true
[[steps]]
kind = "legacy_visualization"
[steps.params]
data_root = "./crack500"
test_file = "./crack500/test.txt"
output_dir = "./results/point_prompt_5pts_hf"
pred_dir = "./results/point_prompt_5pts_hf/predictions"
results_csv = "./results/point_prompt_5pts_hf/evaluation_results.csv"
num_samples = 10
save_all = false
create_metrics_plot = true