revert to oldest implemention
This commit is contained in:
parent
4886fc8861
commit
5371973442
@ -1,35 +0,0 @@
|
||||
{
|
||||
"crop_size": null,
|
||||
"data_format": "channels_first",
|
||||
"default_to_square": true,
|
||||
"device": null,
|
||||
"disable_grouping": null,
|
||||
"do_center_crop": null,
|
||||
"do_convert_rgb": true,
|
||||
"do_normalize": false,
|
||||
"do_rescale": false,
|
||||
"do_resize": false,
|
||||
"image_mean": [
|
||||
0.485,
|
||||
0.456,
|
||||
0.406
|
||||
],
|
||||
"image_processor_type": "Sam2ImageProcessorFast",
|
||||
"image_std": [
|
||||
0.229,
|
||||
0.224,
|
||||
0.225
|
||||
],
|
||||
"input_data_format": null,
|
||||
"mask_size": {
|
||||
"height": 256,
|
||||
"width": 256
|
||||
},
|
||||
"processor_class": "Sam2VideoProcessor",
|
||||
"resample": 2,
|
||||
"rescale_factor": 0.00392156862745098,
|
||||
"return_tensors": null,
|
||||
"size": {
|
||||
"longest_edge": 1024
|
||||
}
|
||||
}
|
||||
98
pixi.lock
generated
98
pixi.lock
generated
@ -34,6 +34,7 @@ environments:
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_ha0e22de_103.conda
|
||||
- conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-h8577fbf_0.conda
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||
@ -48,8 +49,10 @@ environments:
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f1/df/8ee1c5dd1e3308b5d5b2f2dfea323bb2f3827da8d654abb6642051199049/ipython-9.8.0-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl
|
||||
@ -81,6 +84,7 @@ environments:
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3b/6c/99acb2f9eb85c29fc6f3a7ac4dccfd992e22666dd08a642b303311326a97/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
||||
@ -88,6 +92,7 @@ environments:
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4f/87/424511bdcd02c8d7acf9f65caa09f291a519b16bd83c3fb3374b3d4ae951/pillow-12.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4b/a6/38c8e2f318bf67d338f4d629e93b0b4b9af331f455f0390ea8ce4a099b26/portalocker-3.2.0-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/12/ff/e93136587c00a543f4bc768b157fac2c47cd77b180d4f4e5c6efb6ea53a2/psutil-7.2.0-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
|
||||
@ -122,6 +127,7 @@ environments:
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
|
||||
- pypi: /home/dustella/projects/sam2
|
||||
packages:
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
|
||||
sha256: fe51de6107f9edc7aa4f786a70f4a883943bc9d39b3bb7307c04c41410990726
|
||||
@ -144,6 +150,12 @@ packages:
|
||||
purls: []
|
||||
size: 23621
|
||||
timestamp: 1650670423406
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
|
||||
name: antlr4-python3-runtime
|
||||
version: 4.9.3
|
||||
sha256: f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b
|
||||
requires_dist:
|
||||
- typing ; python_full_version < '3.5'
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
|
||||
name: asttokens
|
||||
version: 3.0.1
|
||||
@ -540,6 +552,15 @@ packages:
|
||||
- types-tqdm ; extra == 'typing'
|
||||
- types-urllib3 ; extra == 'typing'
|
||||
requires_python: '>=3.8.0'
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl
|
||||
name: hydra-core
|
||||
version: 1.3.2
|
||||
sha256: fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b
|
||||
requires_dist:
|
||||
- omegaconf>=2.2,<2.4
|
||||
- antlr4-python3-runtime==4.9.*
|
||||
- packaging
|
||||
- importlib-resources ; python_full_version < '3.9'
|
||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/icu-78.1-h33c6efd_0.conda
|
||||
sha256: 7d6463d0be5092b2ae8f2fad34dc84de83eab8bd44cc0d4be8931881c973c48f
|
||||
md5: 518e9bbbc3e3486d6a4519192ba690f8
|
||||
@ -623,6 +644,17 @@ packages:
|
||||
- sphinx<6 ; extra == 'full'
|
||||
- tifffile ; extra == 'full'
|
||||
requires_python: '>=3.9'
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz
|
||||
name: iopath
|
||||
version: 0.1.10
|
||||
sha256: 3311c16a4d9137223e20f141655759933e1eda24f8bff166af834af3c645ef01
|
||||
requires_dist:
|
||||
- tqdm
|
||||
- typing-extensions
|
||||
- portalocker
|
||||
- dataclasses ; python_full_version < '3.7'
|
||||
- boto3 ; extra == 'aws'
|
||||
requires_python: '>=3.6'
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
|
||||
name: ipykernel
|
||||
version: 7.1.0
|
||||
@ -1174,6 +1206,15 @@ packages:
|
||||
version: 12.8.90
|
||||
sha256: 5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f
|
||||
requires_python: '>=3'
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl
|
||||
name: omegaconf
|
||||
version: 2.3.0
|
||||
sha256: 7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b
|
||||
requires_dist:
|
||||
- antlr4-python3-runtime==4.9.*
|
||||
- pyyaml>=5.1.0
|
||||
- dataclasses ; python_full_version == '3.6.*'
|
||||
requires_python: '>=3.6'
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||
name: opencv-python
|
||||
version: 4.12.0.88
|
||||
@ -1355,6 +1396,25 @@ packages:
|
||||
- pytest>=8.4.2 ; extra == 'test'
|
||||
- mypy>=1.18.2 ; extra == 'type'
|
||||
requires_python: '>=3.10'
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4b/a6/38c8e2f318bf67d338f4d629e93b0b4b9af331f455f0390ea8ce4a099b26/portalocker-3.2.0-py3-none-any.whl
|
||||
name: portalocker
|
||||
version: 3.2.0
|
||||
sha256: 3cdc5f565312224bc570c49337bd21428bba0ef363bbcf58b9ef4a9f11779968
|
||||
requires_dist:
|
||||
- pywin32>=226 ; sys_platform == 'win32'
|
||||
- portalocker[tests] ; extra == 'docs'
|
||||
- coverage-conditional-plugin>=0.9.0 ; extra == 'tests'
|
||||
- portalocker[redis] ; extra == 'tests'
|
||||
- pytest-cov>=2.8.1 ; extra == 'tests'
|
||||
- pytest-mypy>=0.8.0 ; extra == 'tests'
|
||||
- pytest-rerunfailures>=15.0 ; extra == 'tests'
|
||||
- pytest-timeout>=2.1.0 ; extra == 'tests'
|
||||
- pytest>=5.4.1 ; extra == 'tests'
|
||||
- sphinx>=6.0.0 ; extra == 'tests'
|
||||
- types-pywin32>=310.0.0.20250429 ; extra == 'tests'
|
||||
- types-redis ; extra == 'tests'
|
||||
- redis ; extra == 'redis'
|
||||
requires_python: '>=3.9'
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
|
||||
name: prompt-toolkit
|
||||
version: 3.0.52
|
||||
@ -1544,6 +1604,44 @@ packages:
|
||||
- safetensors[testing] ; extra == 'all'
|
||||
- safetensors[all] ; extra == 'dev'
|
||||
requires_python: '>=3.9'
|
||||
- pypi: /home/dustella/projects/sam2
|
||||
name: sam-2
|
||||
version: '1.0'
|
||||
sha256: f3cc0abfc266b6f57a457d26bdf52fcaa92567d8fa27ba6ff59ebe60bec0ac69
|
||||
requires_dist:
|
||||
- torch>=2.5.1
|
||||
- torchvision>=0.20.1
|
||||
- numpy>=1.24.4
|
||||
- tqdm>=4.66.1
|
||||
- hydra-core>=1.3.2
|
||||
- iopath>=0.1.10
|
||||
- pillow>=9.4.0
|
||||
- matplotlib>=3.9.1 ; extra == 'notebooks'
|
||||
- jupyter>=1.0.0 ; extra == 'notebooks'
|
||||
- opencv-python>=4.7.0 ; extra == 'notebooks'
|
||||
- eva-decord>=0.6.1 ; extra == 'notebooks'
|
||||
- flask>=3.0.3 ; extra == 'interactive-demo'
|
||||
- flask-cors>=5.0.0 ; extra == 'interactive-demo'
|
||||
- av>=13.0.0 ; extra == 'interactive-demo'
|
||||
- dataclasses-json>=0.6.7 ; extra == 'interactive-demo'
|
||||
- eva-decord>=0.6.1 ; extra == 'interactive-demo'
|
||||
- gunicorn>=23.0.0 ; extra == 'interactive-demo'
|
||||
- imagesize>=1.4.1 ; extra == 'interactive-demo'
|
||||
- pycocotools>=2.0.8 ; extra == 'interactive-demo'
|
||||
- strawberry-graphql>=0.243.0 ; extra == 'interactive-demo'
|
||||
- black==24.2.0 ; extra == 'dev'
|
||||
- usort==1.0.2 ; extra == 'dev'
|
||||
- ufmt==2.0.0b2 ; extra == 'dev'
|
||||
- fvcore>=0.1.5.post20221221 ; extra == 'dev'
|
||||
- pandas>=2.2.2 ; extra == 'dev'
|
||||
- scikit-image>=0.24.0 ; extra == 'dev'
|
||||
- tensorboard>=2.17.0 ; extra == 'dev'
|
||||
- pycocotools>=2.0.8 ; extra == 'dev'
|
||||
- tensordict>=0.6.0 ; extra == 'dev'
|
||||
- opencv-python>=4.7.0 ; extra == 'dev'
|
||||
- submitit>=1.5.1 ; extra == 'dev'
|
||||
requires_python: '>=3.10.0'
|
||||
editable: true
|
||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f4/a2/70401a107d6d7466d64b466927e6b96fcefa99d57494b972608e2f8be50f/scikit_image-0.26.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
||||
name: scikit-image
|
||||
version: 0.26.0
|
||||
|
||||
@ -26,3 +26,4 @@ tqdm = ">=4.65.0"
|
||||
pandas = ">=2.0.0"
|
||||
transformers = ">=4.57.3, <5"
|
||||
ipykernel = ">=7.1.0, <8"
|
||||
sam-2 = { path = "/home/dustella/projects/sam2", editable = true }
|
||||
|
||||
@ -1,136 +1,248 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SAM2 边界框提示方式完整评估流程 (TaskRunner 驱动版本)
|
||||
SAM2 边界框提示方式完整评估流程
|
||||
包括:推理 -> 评估 -> 可视化
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from src.tasks.config import TaskConfig, TaskStepConfig
|
||||
from src.tasks.io import load_task_from_toml
|
||||
from src.tasks.pipeline import TaskRunner
|
||||
# 添加 src 目录到路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||
|
||||
from bbox_prompt import process_test_set, build_sam2, SAM2ImagePredictor
|
||||
from evaluation import evaluate_test_set
|
||||
from visualization import visualize_test_set, create_metrics_distribution_plot
|
||||
|
||||
|
||||
@dataclass
|
||||
class BBoxCLIArgs:
|
||||
data_root: str
|
||||
test_file: str
|
||||
model_id: str
|
||||
output_dir: str
|
||||
expand_ratio: float
|
||||
num_vis: int
|
||||
vis_all: bool
|
||||
skip_inference: bool
|
||||
skip_evaluation: bool
|
||||
skip_visualization: bool
|
||||
config_name: str
|
||||
task_file: Optional[str]
|
||||
|
||||
|
||||
def parse_args() -> BBoxCLIArgs:
|
||||
def parse_args():
|
||||
"""解析命令行参数"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="SAM2 边界框提示方式 - TaskRunner 驱动完整评估"
|
||||
description="SAM2 边界框提示方式 - Crack500 数据集完整评估"
|
||||
)
|
||||
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
||||
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
|
||||
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
|
||||
parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="输出目录")
|
||||
parser.add_argument("--expand_ratio", type=float, default=0.05, help="边界框扩展比例 (0.0-1.0)")
|
||||
parser.add_argument("--num_vis", type=int, default=20, help="可视化样本数量")
|
||||
parser.add_argument("--vis_all", action="store_true", help="可视化所有样本")
|
||||
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
|
||||
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
|
||||
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
|
||||
|
||||
# 数据集参数
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default="sam2_bbox_prompt",
|
||||
help="ProjectConfig 名称(来自 ConfigRegistry)",
|
||||
"--data_root", type=str, default="./crack500",
|
||||
help="数据集根目录"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="可选:指向 TOML 任务配置(若提供则忽略其余 CLI 参数)",
|
||||
"--test_file", type=str, default="./crack500/test.txt",
|
||||
help="测试集文件路径"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return BBoxCLIArgs(
|
||||
data_root=args.data_root,
|
||||
test_file=args.test_file,
|
||||
model_id=args.model_id,
|
||||
output_dir=args.output_dir,
|
||||
expand_ratio=args.expand_ratio,
|
||||
num_vis=args.num_vis,
|
||||
vis_all=args.vis_all,
|
||||
skip_inference=args.skip_inference,
|
||||
skip_evaluation=args.skip_evaluation,
|
||||
skip_visualization=args.skip_visualization,
|
||||
config_name=args.config_name,
|
||||
task_file=args.task_file,
|
||||
|
||||
# 模型参数
|
||||
parser.add_argument(
|
||||
"--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
|
||||
help="SAM2 模型检查点路径"
|
||||
)
|
||||
|
||||
|
||||
def build_cli_task(args: BBoxCLIArgs) -> TaskConfig:
|
||||
steps: List[TaskStepConfig] = []
|
||||
common = {
|
||||
"data_root": args.data_root,
|
||||
"test_file": args.test_file,
|
||||
"model_id": args.model_id,
|
||||
"output_dir": args.output_dir,
|
||||
}
|
||||
if not args.skip_inference:
|
||||
steps.append(
|
||||
TaskStepConfig(
|
||||
kind="bbox_inference",
|
||||
params={**common, "expand_ratio": args.expand_ratio},
|
||||
)
|
||||
)
|
||||
if not args.skip_evaluation:
|
||||
steps.append(
|
||||
TaskStepConfig(
|
||||
kind="legacy_evaluation",
|
||||
params={
|
||||
**common,
|
||||
"pred_dir": f"{args.output_dir}/predictions",
|
||||
"compute_skeleton": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
if not args.skip_visualization:
|
||||
steps.append(
|
||||
TaskStepConfig(
|
||||
kind="legacy_visualization",
|
||||
params={
|
||||
**common,
|
||||
"pred_dir": f"{args.output_dir}/predictions",
|
||||
"results_csv": f"{args.output_dir}/evaluation_results.csv",
|
||||
"num_samples": args.num_vis,
|
||||
"save_all": args.vis_all,
|
||||
"create_metrics_plot": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
return TaskConfig(
|
||||
name="bbox_cli_run",
|
||||
description="Legacy bbox prompt pipeline executed via TaskRunner",
|
||||
project_config_name=args.config_name,
|
||||
steps=steps,
|
||||
parser.add_argument(
|
||||
"--model_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_s.yaml",
|
||||
help="SAM2 模型配置文件"
|
||||
)
|
||||
|
||||
# 输出参数
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default="./results/bbox_prompt",
|
||||
help="输出目录"
|
||||
)
|
||||
|
||||
# 边界框参数
|
||||
parser.add_argument(
|
||||
"--expand_ratio", type=float, default=0.05,
|
||||
help="边界框扩展比例 (0.0-1.0)"
|
||||
)
|
||||
|
||||
# 可视化参数
|
||||
parser.add_argument(
|
||||
"--num_vis", type=int, default=20,
|
||||
help="可视化样本数量"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vis_all", action="store_true",
|
||||
help="可视化所有样本"
|
||||
)
|
||||
|
||||
# 流程控制
|
||||
parser.add_argument(
|
||||
"--skip_inference", action="store_true",
|
||||
help="跳过推理步骤(使用已有预测结果)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_evaluation", action="store_true",
|
||||
help="跳过评估步骤"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_visualization", action="store_true",
|
||||
help="跳过可视化步骤"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
def main():
|
||||
"""主函数"""
|
||||
args = parse_args()
|
||||
if args.task_file:
|
||||
task = load_task_from_toml(args.task_file)
|
||||
|
||||
print("=" * 80)
|
||||
print("SAM2 边界框提示方式 - Crack500 数据集完整评估")
|
||||
print("=" * 80)
|
||||
print(f"数据集根目录: {args.data_root}")
|
||||
print(f"测试集文件: {args.test_file}")
|
||||
print(f"模型检查点: {args.checkpoint}")
|
||||
print(f"模型配置: {args.model_cfg}")
|
||||
print(f"边界框扩展比例: {args.expand_ratio * 100}%")
|
||||
print(f"输出目录: {args.output_dir}")
|
||||
print("=" * 80)
|
||||
|
||||
# 检查必要文件
|
||||
if not os.path.exists(args.data_root):
|
||||
print(f"\n错误: 数据集目录不存在 {args.data_root}")
|
||||
return
|
||||
|
||||
if not os.path.exists(args.test_file):
|
||||
print(f"\n错误: 测试集文件不存在 {args.test_file}")
|
||||
return
|
||||
|
||||
# 创建输出目录
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# ========== 步骤 1: 推理 ==========
|
||||
if not args.skip_inference:
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤 1/3: 使用 SAM2 进行推理")
|
||||
print("=" * 80)
|
||||
|
||||
# 检查模型文件
|
||||
if not os.path.exists(args.checkpoint):
|
||||
print(f"\n错误: 模型检查点不存在 {args.checkpoint}")
|
||||
print("请先下载 SAM2 模型权重!")
|
||||
print("运行: cd sam2/checkpoints && ./download_ckpts.sh")
|
||||
return
|
||||
|
||||
try:
|
||||
import torch
|
||||
|
||||
# 检查 CUDA
|
||||
if not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||||
else:
|
||||
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# 加载模型
|
||||
print("\n加载 SAM2 模型...")
|
||||
start_time = time.time()
|
||||
sam2_model = build_sam2(args.model_cfg, args.checkpoint)
|
||||
predictor = SAM2ImagePredictor(sam2_model)
|
||||
print(f"模型加载完成!耗时: {time.time() - start_time:.2f}s")
|
||||
|
||||
# 处理测试集
|
||||
print("\n开始推理...")
|
||||
start_time = time.time()
|
||||
results = process_test_set(
|
||||
data_root=args.data_root,
|
||||
test_file=args.test_file,
|
||||
predictor=predictor,
|
||||
output_dir=args.output_dir,
|
||||
expand_ratio=args.expand_ratio
|
||||
)
|
||||
print(f"推理完成!耗时: {time.time() - start_time:.2f}s")
|
||||
print(f"成功处理 {len(results)} 张图像")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n推理过程出错: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
else:
|
||||
task = build_cli_task(args)
|
||||
if not task.steps:
|
||||
raise ValueError("No steps configured for bbox evaluation. Please enable at least one stage.")
|
||||
runner = TaskRunner(task)
|
||||
runner.run()
|
||||
print("\n跳过推理步骤(使用已有预测结果)")
|
||||
|
||||
# ========== 步骤 2: 评估 ==========
|
||||
if not args.skip_evaluation:
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤 2/3: 评估预测结果")
|
||||
print("=" * 80)
|
||||
|
||||
pred_dir = os.path.join(args.output_dir, "predictions")
|
||||
|
||||
if not os.path.exists(pred_dir):
|
||||
print(f"\n错误: 预测目录不存在 {pred_dir}")
|
||||
print("请先运行推理步骤!")
|
||||
return
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
df_results = evaluate_test_set(
|
||||
data_root=args.data_root,
|
||||
test_file=args.test_file,
|
||||
pred_dir=pred_dir,
|
||||
output_dir=args.output_dir,
|
||||
compute_skeleton=True
|
||||
)
|
||||
print(f"\n评估完成!耗时: {time.time() - start_time:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n评估过程出错: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
else:
|
||||
print("\n跳过评估步骤")
|
||||
|
||||
# ========== 步骤 3: 可视化 ==========
|
||||
if not args.skip_visualization:
|
||||
print("\n" + "=" * 80)
|
||||
print("步骤 3/3: 生成可视化结果")
|
||||
print("=" * 80)
|
||||
|
||||
pred_dir = os.path.join(args.output_dir, "predictions")
|
||||
results_csv = os.path.join(args.output_dir, "evaluation_results.csv")
|
||||
|
||||
if not os.path.exists(pred_dir):
|
||||
print(f"\n错误: 预测目录不存在 {pred_dir}")
|
||||
return
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# 可视化样本
|
||||
visualize_test_set(
|
||||
data_root=args.data_root,
|
||||
test_file=args.test_file,
|
||||
pred_dir=pred_dir,
|
||||
output_dir=args.output_dir,
|
||||
results_csv=results_csv if os.path.exists(results_csv) else None,
|
||||
num_samples=args.num_vis,
|
||||
save_all=args.vis_all
|
||||
)
|
||||
|
||||
# 创建指标分布图
|
||||
if os.path.exists(results_csv):
|
||||
create_metrics_distribution_plot(results_csv, args.output_dir)
|
||||
|
||||
print(f"\n可视化完成!耗时: {time.time() - start_time:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n可视化过程出错: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
else:
|
||||
print("\n跳过可视化步骤")
|
||||
|
||||
# ========== 完成 ==========
|
||||
print("\n" + "=" * 80)
|
||||
print("所有步骤完成!")
|
||||
print("=" * 80)
|
||||
print(f"\n结果保存在: {args.output_dir}")
|
||||
print(f" - 预测掩码: {os.path.join(args.output_dir, 'predictions')}")
|
||||
print(f" - 评估结果: {os.path.join(args.output_dir, 'evaluation_results.csv')}")
|
||||
print(f" - 统计摘要: {os.path.join(args.output_dir, 'evaluation_summary.json')}")
|
||||
print(f" - 可视化图像: {os.path.join(args.output_dir, 'visualizations')}")
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,222 +1,272 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SAM2 点提示方式完整评估流程 (TaskRunner 驱动版本)
|
||||
SAM2 点提示方式完整评估流程
|
||||
支持 1, 3, 5 个点的对比实验
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
import sys
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pandas as pd
|
||||
# 添加 src 目录到路径
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||
|
||||
from src.tasks.config import TaskConfig, TaskStepConfig
|
||||
from src.tasks.io import load_task_from_toml
|
||||
from src.tasks.pipeline import TaskRunner
|
||||
from point_prompt import process_test_set, build_sam2, SAM2ImagePredictor
|
||||
from evaluation import evaluate_test_set
|
||||
from visualization import visualize_test_set, create_metrics_distribution_plot
|
||||
|
||||
|
||||
@dataclass
|
||||
class PointCLIArgs:
|
||||
data_root: str
|
||||
test_file: str
|
||||
model_id: str
|
||||
point_configs: List[int]
|
||||
per_component: bool
|
||||
num_vis: int
|
||||
skip_inference: bool
|
||||
skip_evaluation: bool
|
||||
skip_visualization: bool
|
||||
skip_comparison: bool
|
||||
comparison_dir: str
|
||||
config_name: str
|
||||
task_file: Optional[str]
|
||||
|
||||
|
||||
def parse_args() -> PointCLIArgs:
|
||||
parser = argparse.ArgumentParser(description="SAM2 点提示方式 - TaskRunner 驱动多点数对比实验")
|
||||
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
||||
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
|
||||
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
|
||||
parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="要测试的点数配置")
|
||||
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样点")
|
||||
parser.add_argument("--num_vis", type=int, default=10, help="可视化样本数量")
|
||||
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
|
||||
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
|
||||
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
|
||||
parser.add_argument("--skip_comparison", action="store_true", help="跳过实验结果对比")
|
||||
parser.add_argument("--comparison_dir", type=str, default="./results", help="对比结果输出目录")
|
||||
parser.add_argument(
|
||||
"--config_name",
|
||||
type=str,
|
||||
default="sam2_bbox_prompt",
|
||||
help="ProjectConfig 名称(来自 ConfigRegistry)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="可选:指向 TOML 任务配置(若提供则跳过 CLI 组装步骤)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return PointCLIArgs(
|
||||
data_root=args.data_root,
|
||||
test_file=args.test_file,
|
||||
model_id=args.model_id,
|
||||
point_configs=args.point_configs,
|
||||
per_component=args.per_component,
|
||||
num_vis=args.num_vis,
|
||||
skip_inference=args.skip_inference,
|
||||
skip_evaluation=args.skip_evaluation,
|
||||
skip_visualization=args.skip_visualization,
|
||||
skip_comparison=args.skip_comparison,
|
||||
comparison_dir=args.comparison_dir,
|
||||
config_name=args.config_name,
|
||||
task_file=args.task_file,
|
||||
)
|
||||
|
||||
|
||||
def default_output_dir(num_points: int, per_component: bool) -> str:
|
||||
def run_single_experiment(
|
||||
data_root: str,
|
||||
test_file: str,
|
||||
checkpoint: str,
|
||||
model_cfg: str,
|
||||
num_points: int,
|
||||
per_component: bool = False
|
||||
):
|
||||
"""运行单个点数的实验"""
|
||||
|
||||
# 设置输出目录
|
||||
if per_component:
|
||||
return f"./results/point_prompt_{num_points}pts_per_comp_hf"
|
||||
return f"./results/point_prompt_{num_points}pts_hf"
|
||||
|
||||
|
||||
def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
|
||||
steps: List[TaskStepConfig] = []
|
||||
common = {
|
||||
"data_root": args.data_root,
|
||||
"test_file": args.test_file,
|
||||
"model_id": args.model_id,
|
||||
"output_dir": output_dir,
|
||||
}
|
||||
if not args.skip_inference:
|
||||
steps.append(
|
||||
TaskStepConfig(
|
||||
kind="point_inference",
|
||||
params={
|
||||
**common,
|
||||
"num_points": num_points,
|
||||
"per_component": args.per_component,
|
||||
},
|
||||
)
|
||||
)
|
||||
if not args.skip_evaluation:
|
||||
steps.append(
|
||||
TaskStepConfig(
|
||||
kind="legacy_evaluation",
|
||||
params={
|
||||
**common,
|
||||
"pred_dir": f"{output_dir}/predictions",
|
||||
"compute_skeleton": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
if not args.skip_visualization:
|
||||
steps.append(
|
||||
TaskStepConfig(
|
||||
kind="legacy_visualization",
|
||||
params={
|
||||
**common,
|
||||
"pred_dir": f"{output_dir}/predictions",
|
||||
"results_csv": f"{output_dir}/evaluation_results.csv",
|
||||
"num_samples": args.num_vis,
|
||||
"save_all": False,
|
||||
"create_metrics_plot": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
return TaskConfig(
|
||||
name=f"point_cli_{num_points}",
|
||||
description=f"Legacy point prompt pipeline ({num_points} pts)",
|
||||
project_config_name=args.config_name,
|
||||
steps=steps,
|
||||
output_dir = f"./results/point_prompt_{num_points}pts_per_comp"
|
||||
else:
|
||||
output_dir = f"./results/point_prompt_{num_points}pts"
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print(f"实验配置: {num_points} 个点 ({'每连通域' if per_component else '全局骨架'})")
|
||||
print("=" * 80)
|
||||
|
||||
# 加载模型
|
||||
print("\n加载 SAM2 模型...")
|
||||
import torch
|
||||
sam2_model = build_sam2(model_cfg, checkpoint)
|
||||
predictor = SAM2ImagePredictor(sam2_model)
|
||||
print("模型加载完成!")
|
||||
|
||||
# 推理
|
||||
print(f"\n步骤 1/3: 推理 ({num_points} 个点)")
|
||||
start_time = time.time()
|
||||
results = process_test_set(
|
||||
data_root=data_root,
|
||||
test_file=test_file,
|
||||
predictor=predictor,
|
||||
output_dir=output_dir,
|
||||
num_points=num_points,
|
||||
per_component=per_component
|
||||
)
|
||||
print(f"推理完成!耗时: {time.time() - start_time:.2f}s")
|
||||
|
||||
# 评估
|
||||
print(f"\n步骤 2/3: 评估")
|
||||
pred_dir = os.path.join(output_dir, "predictions")
|
||||
start_time = time.time()
|
||||
df_results = evaluate_test_set(
|
||||
data_root=data_root,
|
||||
test_file=test_file,
|
||||
pred_dir=pred_dir,
|
||||
output_dir=output_dir,
|
||||
compute_skeleton=True
|
||||
)
|
||||
print(f"评估完成!耗时: {time.time() - start_time:.2f}s")
|
||||
|
||||
# 可视化
|
||||
print(f"\n步骤 3/3: 可视化")
|
||||
results_csv = os.path.join(output_dir, "evaluation_results.csv")
|
||||
start_time = time.time()
|
||||
visualize_test_set(
|
||||
data_root=data_root,
|
||||
test_file=test_file,
|
||||
pred_dir=pred_dir,
|
||||
output_dir=output_dir,
|
||||
results_csv=results_csv,
|
||||
num_samples=10,
|
||||
save_all=False
|
||||
)
|
||||
create_metrics_distribution_plot(results_csv, output_dir)
|
||||
print(f"可视化完成!耗时: {time.time() - start_time:.2f}s")
|
||||
|
||||
return df_results
|
||||
|
||||
|
||||
def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
|
||||
csv_path = Path(output_dir) / "evaluation_results.csv"
|
||||
if not csv_path.exists():
|
||||
return None
|
||||
return pd.read_csv(csv_path)
|
||||
|
||||
|
||||
def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
|
||||
if not results:
|
||||
return
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
summary_rows = []
|
||||
for num_points, df in results.items():
|
||||
summary_rows.append(
|
||||
{
|
||||
"num_points": num_points,
|
||||
"iou_mean": df["iou"].mean(),
|
||||
"iou_std": df["iou"].std(),
|
||||
"dice_mean": df["dice"].mean(),
|
||||
"dice_std": df["dice"].std(),
|
||||
"f1_mean": df["f1_score"].mean(),
|
||||
"f1_std": df["f1_score"].std(),
|
||||
"precision_mean": df["precision"].mean(),
|
||||
"recall_mean": df["recall"].mean(),
|
||||
}
|
||||
)
|
||||
df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
|
||||
summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
|
||||
summary_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df_summary.to_csv(summary_path, index=False)
|
||||
|
||||
def compare_results(results_dict: dict, output_dir: str = "./results"):
|
||||
"""对比不同点数的结果"""
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
metrics_to_plot = [
|
||||
("iou_mean", "iou_std", "IoU"),
|
||||
("dice_mean", "dice_std", "Dice"),
|
||||
("f1_mean", "f1_std", "F1-Score"),
|
||||
]
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("对比不同点数的性能")
|
||||
print("=" * 80)
|
||||
|
||||
# 收集所有结果
|
||||
summary = []
|
||||
for num_points, df in results_dict.items():
|
||||
metrics = {
|
||||
'num_points': num_points,
|
||||
'iou_mean': df['iou'].mean(),
|
||||
'iou_std': df['iou'].std(),
|
||||
'dice_mean': df['dice'].mean(),
|
||||
'dice_std': df['dice'].std(),
|
||||
'f1_mean': df['f1_score'].mean(),
|
||||
'f1_std': df['f1_score'].std(),
|
||||
'precision_mean': df['precision'].mean(),
|
||||
'recall_mean': df['recall'].mean(),
|
||||
}
|
||||
summary.append(metrics)
|
||||
|
||||
df_summary = pd.DataFrame(summary)
|
||||
|
||||
# 打印对比表格
|
||||
print("\n性能对比:")
|
||||
print(df_summary.to_string(index=False))
|
||||
|
||||
# 保存对比结果
|
||||
comparison_dir = os.path.join(output_dir, "point_comparison")
|
||||
os.makedirs(comparison_dir, exist_ok=True)
|
||||
|
||||
csv_path = os.path.join(comparison_dir, "comparison_summary.csv")
|
||||
df_summary.to_csv(csv_path, index=False)
|
||||
print(f"\n对比结果已保存到: {csv_path}")
|
||||
|
||||
# 绘制对比图
|
||||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||||
xs = df_summary["num_points"].tolist()
|
||||
for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
|
||||
ax.errorbar(
|
||||
xs,
|
||||
df_summary[mean_col],
|
||||
yerr=df_summary[std_col],
|
||||
marker="o",
|
||||
capsize=5,
|
||||
linewidth=2,
|
||||
markersize=8,
|
||||
)
|
||||
ax.set_xlabel("Number of Points", fontsize=12)
|
||||
|
||||
metrics_to_plot = [
|
||||
('iou_mean', 'iou_std', 'IoU'),
|
||||
('dice_mean', 'dice_std', 'Dice'),
|
||||
('f1_mean', 'f1_std', 'F1-Score')
|
||||
]
|
||||
|
||||
for idx, (mean_col, std_col, title) in enumerate(metrics_to_plot):
|
||||
ax = axes[idx]
|
||||
x = df_summary['num_points']
|
||||
y = df_summary[mean_col]
|
||||
yerr = df_summary[std_col]
|
||||
|
||||
ax.errorbar(x, y, yerr=yerr, marker='o', capsize=5, linewidth=2, markersize=8)
|
||||
ax.set_xlabel('Number of Points', fontsize=12)
|
||||
ax.set_ylabel(title, fontsize=12)
|
||||
ax.set_title(f"{title} vs Number of Points", fontsize=14)
|
||||
ax.set_title(f'{title} vs Number of Points', fontsize=14)
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.set_xticks(xs)
|
||||
ax.set_xticks(x)
|
||||
|
||||
plt.tight_layout()
|
||||
plot_path = summary_path.with_name("performance_comparison.png")
|
||||
fig.savefig(plot_path, dpi=150, bbox_inches="tight")
|
||||
|
||||
plot_path = os.path.join(comparison_dir, "performance_comparison.png")
|
||||
fig.savefig(plot_path, dpi=150, bbox_inches='tight')
|
||||
plt.close(fig)
|
||||
|
||||
print(f"对比图已保存到: {plot_path}")
|
||||
|
||||
return df_summary
|
||||
|
||||
|
||||
def main() -> None:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
args = parse_args()
|
||||
if args.task_file:
|
||||
task = load_task_from_toml(args.task_file)
|
||||
TaskRunner(task).run()
|
||||
return
|
||||
|
||||
comparison_data: Dict[int, pd.DataFrame] = {}
|
||||
def main():
|
||||
"""主函数"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="SAM2 点提示方式 - 多点数对比实验"
|
||||
)
|
||||
|
||||
# 数据集参数
|
||||
parser.add_argument(
|
||||
"--data_root", type=str, default="./crack500",
|
||||
help="数据集根目录"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test_file", type=str, default="./crack500/test.txt",
|
||||
help="测试集文件路径"
|
||||
)
|
||||
|
||||
# 模型参数
|
||||
parser.add_argument(
|
||||
"--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
|
||||
help="SAM2 模型检查点路径"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_cfg", type=str, default="sam2.1_hiera_s.yaml",
|
||||
help="SAM2 模型配置文件"
|
||||
)
|
||||
|
||||
# 实验参数
|
||||
parser.add_argument(
|
||||
"--point_configs", type=int, nargs='+', default=[1, 3, 5],
|
||||
help="要测试的点数配置"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_component", action="store_true",
|
||||
help="为每个连通域独立采样"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip_comparison", action="store_true",
|
||||
help="跳过对比分析"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=" * 80)
|
||||
print("SAM2 点提示方式 - 多点数对比实验")
|
||||
print("=" * 80)
|
||||
print(f"数据集根目录: {args.data_root}")
|
||||
print(f"测试集文件: {args.test_file}")
|
||||
print(f"模型检查点: {args.checkpoint}")
|
||||
print(f"点数配置: {args.point_configs}")
|
||||
print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
|
||||
print("=" * 80)
|
||||
|
||||
# 检查 CUDA
|
||||
import torch
|
||||
if not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||||
else:
|
||||
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# 运行所有实验
|
||||
results_dict = {}
|
||||
|
||||
for num_points in args.point_configs:
|
||||
output_dir = default_output_dir(num_points, args.per_component)
|
||||
task = build_task_for_points(args, num_points, output_dir)
|
||||
if not task.steps:
|
||||
try:
|
||||
df_results = run_single_experiment(
|
||||
data_root=args.data_root,
|
||||
test_file=args.test_file,
|
||||
checkpoint=args.checkpoint,
|
||||
model_cfg=args.model_cfg,
|
||||
num_points=num_points,
|
||||
per_component=args.per_component
|
||||
)
|
||||
results_dict[num_points] = df_results
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n实验失败 ({num_points} 个点): {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
TaskRunner(task).run()
|
||||
if not args.skip_comparison and not args.skip_evaluation:
|
||||
df = load_results_csv(output_dir)
|
||||
if df is not None:
|
||||
comparison_data[num_points] = df
|
||||
if not args.skip_comparison and comparison_data:
|
||||
compare_results(comparison_data, args.comparison_dir)
|
||||
|
||||
# 对比分析
|
||||
if not args.skip_comparison and len(results_dict) > 1:
|
||||
try:
|
||||
compare_results(results_dict)
|
||||
except Exception as e:
|
||||
print(f"\n对比分析失败: {str(e)}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# 完成
|
||||
print("\n" + "=" * 80)
|
||||
print("所有实验完成!")
|
||||
print("=" * 80)
|
||||
print("\n结果目录:")
|
||||
for num_points in args.point_configs:
|
||||
if args.per_component:
|
||||
output_dir = f"./results/point_prompt_{num_points}pts_per_comp"
|
||||
else:
|
||||
output_dir = f"./results/point_prompt_{num_points}pts"
|
||||
print(f" - {num_points} 个点: {output_dir}")
|
||||
|
||||
if not args.skip_comparison and len(results_dict) > 1:
|
||||
print(f" - 对比分析: ./results/point_comparison")
|
||||
|
||||
print("=" * 80)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,39 +1,158 @@
|
||||
"""
|
||||
边界框提示方式的 SAM2 裂缝分割实现(使用 HuggingFace Transformers)
|
||||
边界框提示方式的 SAM2 裂缝分割实现
|
||||
从 GT 掩码中提取边界框,使用 SAM2 进行分割
|
||||
"""
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from typing import List, Tuple, Dict
|
||||
from tqdm import tqdm
|
||||
import json
|
||||
import cv2
|
||||
|
||||
from .dataset.utils import extract_bboxes_from_mask, load_image_and_mask
|
||||
from .hf_sam2_predictor import HFSam2Predictor
|
||||
from .model.inference import predict_with_bbox_prompt
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
|
||||
def extract_bboxes_from_mask(mask: np.ndarray, expand_ratio: float = 0.0) -> List[np.ndarray]:
|
||||
"""
|
||||
从二值掩码中提取所有连通域的边界框
|
||||
|
||||
Args:
|
||||
mask: 二值掩码 (H, W),值为 0 或 255
|
||||
expand_ratio: 边界框扩展比例,例如 0.05 表示扩展 5%
|
||||
|
||||
Returns:
|
||||
List of bounding boxes in format [x1, y1, x2, y2]
|
||||
"""
|
||||
# 确保掩码是二值的
|
||||
binary_mask = (mask > 0).astype(np.uint8)
|
||||
|
||||
# 连通域分析
|
||||
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
|
||||
binary_mask, connectivity=8
|
||||
)
|
||||
|
||||
bboxes = []
|
||||
# 跳过背景 (label 0)
|
||||
for i in range(1, num_labels):
|
||||
x, y, w, h, area = stats[i]
|
||||
|
||||
# 过滤太小的区域(可能是噪声)
|
||||
if area < 10: # 最小面积阈值
|
||||
continue
|
||||
|
||||
# 计算边界框
|
||||
x1, y1 = x, y
|
||||
x2, y2 = x + w, y + h
|
||||
|
||||
# 扩展边界框(如果需要)
|
||||
if expand_ratio > 0:
|
||||
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
|
||||
w_new = w * (1 + expand_ratio)
|
||||
h_new = h * (1 + expand_ratio)
|
||||
x1 = max(0, int(cx - w_new / 2))
|
||||
y1 = max(0, int(cy - h_new / 2))
|
||||
x2 = min(mask.shape[1], int(cx + w_new / 2))
|
||||
y2 = min(mask.shape[0], int(cy + h_new / 2))
|
||||
|
||||
bboxes.append(np.array([x1, y1, x2, y2]))
|
||||
|
||||
return bboxes
|
||||
|
||||
|
||||
def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
加载图像和掩码
|
||||
|
||||
Args:
|
||||
image_path: 图像路径
|
||||
mask_path: 掩码路径
|
||||
|
||||
Returns:
|
||||
image: RGB 图像 (H, W, 3)
|
||||
mask: 二值掩码 (H, W)
|
||||
"""
|
||||
# 加载图像
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise ValueError(f"无法加载图像: {image_path}")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 加载掩码
|
||||
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
||||
if mask is None:
|
||||
raise ValueError(f"无法加载掩码: {mask_path}")
|
||||
|
||||
return image, mask
|
||||
|
||||
|
||||
def predict_with_bbox_prompt(
|
||||
predictor: SAM2ImagePredictor,
|
||||
image: np.ndarray,
|
||||
bboxes: List[np.ndarray]
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
使用边界框提示进行 SAM2 预测
|
||||
|
||||
Args:
|
||||
predictor: SAM2ImagePredictor 实例
|
||||
image: RGB 图像 (H, W, 3)
|
||||
bboxes: 边界框列表,每个为 [x1, y1, x2, y2]
|
||||
|
||||
Returns:
|
||||
combined_mask: 合并后的预测掩码 (H, W)
|
||||
"""
|
||||
# 设置图像
|
||||
predictor.set_image(image)
|
||||
|
||||
# 如果没有边界框,返回空掩码
|
||||
if len(bboxes) == 0:
|
||||
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||
|
||||
# 合并所有预测掩码
|
||||
combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||
|
||||
# 对每个边界框进行预测
|
||||
for bbox in bboxes:
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
box=bbox[None, :], # shape: (1, 4)
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
# 取第一个掩码(因为 multimask_output=False)
|
||||
mask = masks[0] # shape: (H, W)
|
||||
|
||||
# 合并到总掩码中
|
||||
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
|
||||
|
||||
# 转换为 0-255
|
||||
combined_mask = combined_mask * 255
|
||||
|
||||
return combined_mask
|
||||
|
||||
|
||||
def process_test_set(
|
||||
data_root: str,
|
||||
test_file: str,
|
||||
predictor: HFSam2Predictor,
|
||||
predictor: SAM2ImagePredictor,
|
||||
output_dir: str,
|
||||
expand_ratio: float = 0.0
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
处理整个测试集
|
||||
|
||||
|
||||
Args:
|
||||
data_root: 数据集根目录
|
||||
test_file: 测试集文件路径 (test.txt)
|
||||
predictor: HFSam2Predictor 实例
|
||||
predictor: SAM2ImagePredictor 实例
|
||||
output_dir: 输出目录
|
||||
expand_ratio: 边界框扩展比例
|
||||
|
||||
|
||||
Returns:
|
||||
results: 包含每个样本信息的列表
|
||||
"""
|
||||
@ -41,26 +160,26 @@ def process_test_set(
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
pred_dir = os.path.join(output_dir, "predictions")
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
|
||||
|
||||
# 读取测试集文件
|
||||
with open(test_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
|
||||
results = []
|
||||
|
||||
|
||||
print(f"开始处理 {len(lines)} 张测试图像...")
|
||||
|
||||
|
||||
for line in tqdm(lines, desc="处理测试集"):
|
||||
parts = line.strip().split()
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
|
||||
img_rel_path, mask_rel_path = parts
|
||||
|
||||
|
||||
# 构建完整路径
|
||||
img_path = os.path.join(data_root, img_rel_path)
|
||||
mask_path = os.path.join(data_root, mask_rel_path)
|
||||
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(img_path):
|
||||
print(f"警告: 图像不存在 {img_path}")
|
||||
@ -68,23 +187,23 @@ def process_test_set(
|
||||
if not os.path.exists(mask_path):
|
||||
print(f"警告: 掩码不存在 {mask_path}")
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# 加载图像和掩码
|
||||
image, mask_gt = load_image_and_mask(img_path, mask_path)
|
||||
|
||||
|
||||
# 从 GT 掩码提取边界框
|
||||
bboxes = extract_bboxes_from_mask(mask_gt, expand_ratio=expand_ratio)
|
||||
|
||||
|
||||
# 使用 SAM2 预测
|
||||
with torch.inference_mode():
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
mask_pred = predict_with_bbox_prompt(predictor, image, bboxes)
|
||||
|
||||
|
||||
# 保存预测掩码
|
||||
img_name = Path(img_rel_path).stem
|
||||
pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
|
||||
cv2.imwrite(pred_path, mask_pred)
|
||||
|
||||
|
||||
# 记录结果
|
||||
results.append({
|
||||
"image_path": img_rel_path,
|
||||
@ -93,23 +212,20 @@ def process_test_set(
|
||||
"num_bboxes": len(bboxes),
|
||||
"image_shape": image.shape[:2],
|
||||
})
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理失败 {img_path}: {str(e)}")
|
||||
# print stack trace
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
|
||||
# 保存结果信息
|
||||
results_file = os.path.join(output_dir, "results_info.json")
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
|
||||
print(f"\n处理完成!共处理 {len(results)} 张图像")
|
||||
print(f"预测掩码保存在: {pred_dir}")
|
||||
print(f"结果信息保存在: {results_file}")
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@ -118,36 +234,38 @@ def main():
|
||||
# 配置参数
|
||||
DATA_ROOT = "./crack500"
|
||||
TEST_FILE = "./crack500/test.txt"
|
||||
OUTPUT_DIR = "./results/bbox_prompt_hf"
|
||||
|
||||
# HuggingFace SAM2 模型
|
||||
MODEL_ID = "facebook/sam2-hiera-small"
|
||||
|
||||
OUTPUT_DIR = "./results/bbox_prompt"
|
||||
|
||||
# SAM2 模型配置
|
||||
CHECKPOINT = "./sam2/checkpoints/sam2.1_hiera_small.pt"
|
||||
MODEL_CFG = "sam2.1_hiera_s.yaml"
|
||||
|
||||
# 边界框扩展比例
|
||||
EXPAND_RATIO = 0.05 # 5% 扩展
|
||||
|
||||
|
||||
print("=" * 60)
|
||||
print("SAM2 边界框提示方式 (HuggingFace) - Crack500 数据集评估")
|
||||
print("SAM2 边界框提示方式 - Crack500 数据集评估")
|
||||
print("=" * 60)
|
||||
print(f"数据集根目录: {DATA_ROOT}")
|
||||
print(f"测试集文件: {TEST_FILE}")
|
||||
print(f"模型: {MODEL_ID}")
|
||||
print(f"模型检查点: {CHECKPOINT}")
|
||||
print(f"模型配置: {MODEL_CFG}")
|
||||
print(f"边界框扩展比例: {EXPAND_RATIO * 100}%")
|
||||
print(f"输出目录: {OUTPUT_DIR}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 检查 CUDA 是否可用
|
||||
if not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||||
else:
|
||||
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# 构建 SAM2 predictor
|
||||
|
||||
# 构建 SAM2 模型
|
||||
print("\n加载 SAM2 模型...")
|
||||
from .hf_sam2_predictor import build_hf_sam2_predictor
|
||||
predictor = build_hf_sam2_predictor(model_id=MODEL_ID)
|
||||
sam2_model = build_sam2(MODEL_CFG, CHECKPOINT)
|
||||
predictor = SAM2ImagePredictor(sam2_model)
|
||||
print("模型加载完成!")
|
||||
|
||||
|
||||
# 处理测试集
|
||||
results = process_test_set(
|
||||
data_root=DATA_ROOT,
|
||||
@ -156,7 +274,7 @@ def main():
|
||||
output_dir=OUTPUT_DIR,
|
||||
expand_ratio=EXPAND_RATIO
|
||||
)
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("处理完成!接下来请运行评估脚本计算指标。")
|
||||
print("=" * 60)
|
||||
|
||||
@ -1,16 +0,0 @@
|
||||
from .base import BaseDataset, DatasetRecord, ModelReadySample, collate_samples
|
||||
from .registry import DatasetRegistry
|
||||
from .utils import extract_bboxes_from_mask, load_image_and_mask
|
||||
|
||||
# ensure built-in datasets register themselves
|
||||
from . import crack500 # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"BaseDataset",
|
||||
"DatasetRecord",
|
||||
"ModelReadySample",
|
||||
"collate_samples",
|
||||
"DatasetRegistry",
|
||||
"extract_bboxes_from_mask",
|
||||
"load_image_and_mask",
|
||||
]
|
||||
@ -1,167 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from ..model_configuration.config import DatasetConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetRecord:
|
||||
"""
|
||||
Lightweight description of a single sample on disk.
|
||||
"""
|
||||
|
||||
image_path: Path
|
||||
mask_path: Optional[Path] = None
|
||||
prompt_path: Optional[Path] = None
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelReadySample:
|
||||
"""
|
||||
Standard container that mirrors what Hugging Face pipelines expect.
|
||||
"""
|
||||
|
||||
pixel_values: torch.Tensor | np.ndarray
|
||||
prompts: Dict[str, Any] = field(default_factory=dict)
|
||||
labels: Dict[str, Any] = field(default_factory=dict)
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_hf_dict(self) -> Dict[str, Any]:
|
||||
payload = {
|
||||
"pixel_values": self.pixel_values,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
if self.prompts:
|
||||
payload["prompts"] = self.prompts
|
||||
if self.labels:
|
||||
payload["labels"] = self.labels
|
||||
return payload
|
||||
|
||||
|
||||
class BaseDataset(Dataset):
|
||||
"""
|
||||
Common dataset base class that handles record bookkeeping, IO, and
|
||||
formatting tensors for Hugging Face pipelines.
|
||||
"""
|
||||
|
||||
dataset_name: str = "base"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DatasetConfig,
|
||||
transforms: Optional[Callable[[ModelReadySample], ModelReadySample]] = None,
|
||||
return_hf_dict: bool = True,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.transforms = transforms
|
||||
self.return_hf_dict = return_hf_dict
|
||||
self.records: List[DatasetRecord] = self.load_records()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.records)
|
||||
|
||||
def __getitem__(self, index: int) -> Dict[str, Any] | ModelReadySample:
|
||||
record = self.records[index]
|
||||
sample = self.prepare_sample(record)
|
||||
if self.transforms:
|
||||
sample = self.transforms(sample)
|
||||
return sample.to_hf_dict() if self.return_hf_dict else sample
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_records(self) -> List[DatasetRecord]:
|
||||
"""
|
||||
Scan the dataset directory / annotation files and return
|
||||
structured references to each item on disk.
|
||||
"""
|
||||
|
||||
def prepare_sample(self, record: DatasetRecord) -> ModelReadySample:
|
||||
"""
|
||||
Load image/mask/prompt data from disk and wrap it inside ModelReadySample.
|
||||
Subclasses can override this to implement custom augmentations or prompt generation.
|
||||
"""
|
||||
image = self._load_image(record.image_path)
|
||||
mask = (
|
||||
self._load_mask(record.mask_path)
|
||||
if record.mask_path is not None
|
||||
else None
|
||||
)
|
||||
prompts = self.build_prompts(record, mask)
|
||||
labels = {"mask": mask} if mask is not None else {}
|
||||
sample = ModelReadySample(
|
||||
pixel_values=image,
|
||||
prompts=prompts,
|
||||
labels=labels,
|
||||
metadata=record.metadata,
|
||||
)
|
||||
return sample
|
||||
|
||||
def build_prompts(
|
||||
self, record: DatasetRecord, mask: Optional[np.ndarray]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Derive prompts from metadata or masks.
|
||||
Default implementation extracts bounding boxes from masks.
|
||||
"""
|
||||
if mask is None:
|
||||
return {}
|
||||
boxes = self._mask_to_bboxes(mask)
|
||||
return {"boxes": boxes}
|
||||
|
||||
def _load_image(self, path: Path) -> np.ndarray:
|
||||
image = Image.open(path).convert("RGB")
|
||||
return np.array(image)
|
||||
|
||||
def _load_mask(self, path: Optional[Path]) -> Optional[np.ndarray]:
|
||||
if path is None:
|
||||
return None
|
||||
mask = Image.open(path).convert("L")
|
||||
return np.array(mask)
|
||||
|
||||
def _mask_to_bboxes(self, mask: np.ndarray) -> List[List[int]]:
|
||||
"""
|
||||
Helper that mirrors the legacy bbox extraction pipeline.
|
||||
"""
|
||||
if mask.ndim != 2:
|
||||
raise ValueError("Mask must be 2-dimensional.")
|
||||
ys, xs = np.where(mask > 0)
|
||||
if ys.size == 0:
|
||||
return []
|
||||
x_min, x_max = xs.min(), xs.max()
|
||||
y_min, y_max = ys.min(), ys.max()
|
||||
return [[int(x_min), int(y_min), int(x_max), int(y_max)]]
|
||||
|
||||
|
||||
def collate_samples(batch: Iterable[Dict[str, Any] | ModelReadySample]) -> Dict[str, Any]:
|
||||
"""
|
||||
Default collate_fn that merges ModelReadySample/HF dict outputs.
|
||||
"""
|
||||
pixel_values = []
|
||||
prompts: List[Dict[str, Any]] = []
|
||||
labels: List[Dict[str, Any]] = []
|
||||
metadata: List[Dict[str, Any]] = []
|
||||
for item in batch:
|
||||
if isinstance(item, ModelReadySample):
|
||||
payload = item.to_hf_dict()
|
||||
else:
|
||||
payload = item
|
||||
pixel_values.append(payload["pixel_values"])
|
||||
prompts.append(payload.get("prompts", {}))
|
||||
labels.append(payload.get("labels", {}))
|
||||
metadata.append(payload.get("metadata", {}))
|
||||
stacked = {
|
||||
"pixel_values": torch.as_tensor(np.stack(pixel_values)),
|
||||
"prompts": prompts,
|
||||
"labels": labels,
|
||||
"metadata": metadata,
|
||||
}
|
||||
return stacked
|
||||
@ -1,99 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .base import BaseDataset, DatasetRecord
|
||||
from .registry import DatasetRegistry
|
||||
from .utils import (
|
||||
extract_bboxes_from_mask,
|
||||
sample_points_on_skeleton,
|
||||
sample_points_per_component,
|
||||
)
|
||||
from ..model_configuration.config import DatasetConfig
|
||||
|
||||
|
||||
@DatasetRegistry.register("crack500")
|
||||
class Crack500Dataset(BaseDataset):
|
||||
"""
|
||||
Reference implementation that loads Crack500 samples from an image list.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DatasetConfig,
|
||||
expand_ratio: float = 0.05,
|
||||
min_area: int = 10,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
extra = dict(config.extra_params or {})
|
||||
expand_ratio = float(extra.get("expand_ratio", expand_ratio))
|
||||
self.prompt_mode = extra.get("prompt_mode", "bbox")
|
||||
self.num_points = int(extra.get("num_points", 5))
|
||||
self.per_component = bool(extra.get("per_component", False))
|
||||
self.expand_ratio = expand_ratio
|
||||
self.min_area = min_area
|
||||
super().__init__(config, **kwargs)
|
||||
|
||||
def load_records(self) -> List[DatasetRecord]:
|
||||
base_dir = Path(self.config.data_root)
|
||||
list_file = (
|
||||
Path(self.config.annotation_file)
|
||||
if self.config.annotation_file
|
||||
else base_dir / (self.config.split_file or "test.txt")
|
||||
)
|
||||
if not list_file.exists():
|
||||
raise FileNotFoundError(f"Missing Crack500 split file: {list_file}")
|
||||
image_dir = base_dir / (self.config.image_folder or "testcrop")
|
||||
mask_dir = base_dir / (self.config.mask_folder or "testdata")
|
||||
records: List[DatasetRecord] = []
|
||||
with list_file.open("r", encoding="utf-8") as handle:
|
||||
for line in handle:
|
||||
image_name = line.strip()
|
||||
if not image_name:
|
||||
continue
|
||||
image_path = image_dir / image_name
|
||||
mask_name = image_name.replace(".jpg", ".png")
|
||||
mask_path = mask_dir / mask_name
|
||||
metadata = {"split": self.config.split, "image_name": image_name}
|
||||
records.append(
|
||||
DatasetRecord(
|
||||
image_path=image_path,
|
||||
mask_path=mask_path if mask_path.exists() else None,
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
if not records:
|
||||
raise RuntimeError(
|
||||
f"No records found in {image_dir} for split {self.config.split}"
|
||||
)
|
||||
return records
|
||||
|
||||
def build_prompts(
|
||||
self,
|
||||
record: DatasetRecord,
|
||||
mask: Optional[np.ndarray],
|
||||
) -> Dict[str, List[List[int]]]:
|
||||
if mask is None:
|
||||
return {}
|
||||
if self.prompt_mode == "point":
|
||||
points, point_labels = self._build_point_prompts(mask)
|
||||
if points.size == 0:
|
||||
return {}
|
||||
prompts: Dict[str, List[List[int]]] = {"points": points.tolist()}
|
||||
if point_labels.size > 0:
|
||||
prompts["point_labels"] = point_labels.tolist()
|
||||
return prompts
|
||||
boxes = extract_bboxes_from_mask(
|
||||
mask, expand_ratio=self.expand_ratio, min_area=self.min_area
|
||||
)
|
||||
return {"boxes": boxes}
|
||||
|
||||
def _build_point_prompts(self, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
if self.per_component:
|
||||
return sample_points_per_component(mask, self.num_points)
|
||||
points = sample_points_on_skeleton(mask, self.num_points)
|
||||
labels = np.ones(points.shape[0], dtype=np.int32)
|
||||
return points, labels
|
||||
@ -1,33 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from .base import BaseDataset
|
||||
|
||||
|
||||
class DatasetRegistry:
|
||||
"""
|
||||
Simple registry so configs can refer to datasets by string key.
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[BaseDataset]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
def decorator(dataset_cls: Type[BaseDataset]) -> Type[BaseDataset]:
|
||||
cls._registry[name] = dataset_cls
|
||||
dataset_cls.dataset_name = name
|
||||
return dataset_cls
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, name: str, *args, **kwargs) -> BaseDataset:
|
||||
if name not in cls._registry:
|
||||
raise KeyError(f"Dataset '{name}' is not registered.")
|
||||
dataset_cls = cls._registry[name]
|
||||
return dataset_cls(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def available(cls) -> Dict[str, Type[BaseDataset]]:
|
||||
return dict(cls._registry)
|
||||
@ -1,91 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from skimage.morphology import skeletonize
|
||||
|
||||
|
||||
def load_image_and_mask(image_path: str | Path, mask_path: str | Path) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Reads an RGB image and its mask counterpart.
|
||||
"""
|
||||
image_path = str(image_path)
|
||||
mask_path = str(mask_path)
|
||||
image = cv2.imread(image_path)
|
||||
if image is None:
|
||||
raise ValueError(f"无法加载图像: {image_path}")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
||||
if mask is None:
|
||||
raise ValueError(f"无法加载掩码: {mask_path}")
|
||||
return image, mask
|
||||
|
||||
|
||||
def extract_bboxes_from_mask(
|
||||
mask: np.ndarray,
|
||||
expand_ratio: float = 0.0,
|
||||
min_area: int = 10,
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Extract bounding boxes from a binary mask using connected components.
|
||||
"""
|
||||
binary_mask = (mask > 0).astype(np.uint8)
|
||||
num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
|
||||
bboxes: List[List[int]] = []
|
||||
for i in range(1, num_labels):
|
||||
x, y, w, h, area = stats[i]
|
||||
if area < min_area:
|
||||
continue
|
||||
x1, y1 = x, y
|
||||
x2, y2 = x + w, y + h
|
||||
if expand_ratio > 0:
|
||||
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
|
||||
w_new = w * (1 + expand_ratio)
|
||||
h_new = h * (1 + expand_ratio)
|
||||
x1 = max(0, int(cx - w_new / 2))
|
||||
y1 = max(0, int(cy - h_new / 2))
|
||||
x2 = min(mask.shape[1], int(cx + w_new / 2))
|
||||
y2 = min(mask.shape[0], int(cy + h_new / 2))
|
||||
bboxes.append([x1, y1, x2, y2])
|
||||
return bboxes
|
||||
|
||||
|
||||
def sample_points_on_skeleton(mask: np.ndarray, num_points: int) -> np.ndarray:
|
||||
"""
|
||||
Sample points uniformly along the mask skeleton in (x, y) order.
|
||||
"""
|
||||
binary_mask = (mask > 0).astype(bool)
|
||||
try:
|
||||
skeleton = skeletonize(binary_mask)
|
||||
except Exception:
|
||||
skeleton = binary_mask
|
||||
coords = np.argwhere(skeleton)
|
||||
if coords.size == 0:
|
||||
return np.zeros((0, 2), dtype=np.int32)
|
||||
if coords.shape[0] <= num_points:
|
||||
points = coords[:, [1, 0]]
|
||||
return points.astype(np.int32)
|
||||
indices = np.linspace(0, coords.shape[0] - 1, num_points, dtype=int)
|
||||
sampled = coords[indices][:, [1, 0]]
|
||||
return sampled.astype(np.int32)
|
||||
|
||||
|
||||
def sample_points_per_component(mask: np.ndarray, num_points_per_component: int) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Sample points per connected component along each component's skeleton.
|
||||
"""
|
||||
num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
|
||||
all_points = []
|
||||
for region_id in range(1, num_labels):
|
||||
region_mask = (labels_map == region_id).astype(np.uint8) * 255
|
||||
points = sample_points_on_skeleton(region_mask, num_points_per_component)
|
||||
if len(points):
|
||||
all_points.append(points)
|
||||
if not all_points:
|
||||
return np.zeros((0, 2), dtype=np.int32), np.zeros(0, dtype=np.int32)
|
||||
stacked = np.vstack(all_points)
|
||||
labels = np.ones(stacked.shape[0], dtype=np.int32)
|
||||
return stacked, labels
|
||||
@ -1,14 +0,0 @@
|
||||
from .metrics import METRIC_REGISTRY, compute_dice, compute_iou, compute_precision, compute_recall
|
||||
from .pipeline_eval import PipelineEvaluator
|
||||
from .reporting import write_csv, write_json
|
||||
|
||||
__all__ = [
|
||||
"METRIC_REGISTRY",
|
||||
"PipelineEvaluator",
|
||||
"compute_dice",
|
||||
"compute_iou",
|
||||
"compute_precision",
|
||||
"compute_recall",
|
||||
"write_csv",
|
||||
"write_json",
|
||||
]
|
||||
@ -1,57 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable, Dict, Iterable, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_iou(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
||||
pred_bin = (pred >= threshold).astype(np.uint8)
|
||||
target_bin = (target > 0).astype(np.uint8)
|
||||
intersection = (pred_bin & target_bin).sum()
|
||||
union = (pred_bin | target_bin).sum()
|
||||
return float(intersection / union) if union else 0.0
|
||||
|
||||
|
||||
def compute_dice(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
||||
pred_bin = (pred >= threshold).astype(np.uint8)
|
||||
target_bin = (target > 0).astype(np.uint8)
|
||||
intersection = (pred_bin & target_bin).sum()
|
||||
total = pred_bin.sum() + target_bin.sum()
|
||||
return float((2 * intersection) / total) if total else 0.0
|
||||
|
||||
|
||||
def compute_precision(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
||||
pred_bin = (pred >= threshold).astype(np.uint8)
|
||||
target_bin = (target > 0).astype(np.uint8)
|
||||
tp = (pred_bin & target_bin).sum()
|
||||
fp = (pred_bin & (1 - target_bin)).sum()
|
||||
return float(tp / (tp + fp)) if (tp + fp) else 0.0
|
||||
|
||||
|
||||
def compute_recall(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
||||
pred_bin = (pred >= threshold).astype(np.uint8)
|
||||
target_bin = (target > 0).astype(np.uint8)
|
||||
tp = (pred_bin & target_bin).sum()
|
||||
fn = ((1 - pred_bin) & target_bin).sum()
|
||||
return float(tp / (tp + fn)) if (tp + fn) else 0.0
|
||||
|
||||
|
||||
MetricFn = Callable[[np.ndarray, np.ndarray, float], float]
|
||||
|
||||
|
||||
METRIC_REGISTRY: Dict[str, MetricFn] = {
|
||||
"iou": compute_iou,
|
||||
"dice": compute_dice,
|
||||
"precision": compute_precision,
|
||||
"recall": compute_recall,
|
||||
}
|
||||
|
||||
|
||||
def resolve_metrics(metric_names: Iterable[str]) -> Dict[str, MetricFn]:
|
||||
resolved: Dict[str, MetricFn] = {}
|
||||
for name in metric_names:
|
||||
if name not in METRIC_REGISTRY:
|
||||
raise KeyError(f"Metric '{name}' is not registered.")
|
||||
resolved[name] = METRIC_REGISTRY[name]
|
||||
return resolved
|
||||
@ -1,95 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..dataset import BaseDataset
|
||||
from ..model import BaseModelAdapter
|
||||
from ..model_configuration import EvaluationConfig
|
||||
from .metrics import resolve_metrics
|
||||
from .utils import extract_mask_from_pipeline_output
|
||||
|
||||
|
||||
class PipelineEvaluator:
|
||||
"""
|
||||
Runs a Hugging Face pipeline across a dataset and aggregates metrics.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: BaseDataset,
|
||||
adapter: BaseModelAdapter,
|
||||
config: EvaluationConfig,
|
||||
) -> None:
|
||||
self.dataset = dataset
|
||||
self.adapter = adapter
|
||||
self.config = config
|
||||
self.metrics = resolve_metrics(config.metrics)
|
||||
|
||||
def run(self) -> Dict[str, Any]:
|
||||
pipe = self.adapter.build_pipeline()
|
||||
aggregated: Dict[str, List[float]] = {name: [] for name in self.metrics}
|
||||
output_dir = Path(self.config.output_dir)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
requested = self.config.max_samples or len(self.dataset)
|
||||
total = min(requested, len(self.dataset))
|
||||
prog_bar = tqdm(range(total), total=total)
|
||||
for idx in prog_bar:
|
||||
sample = self.dataset[idx]
|
||||
inputs = self._build_pipeline_inputs(sample)
|
||||
preds = pipe(**inputs)
|
||||
labels = sample.get("labels", {})
|
||||
mask = labels.get("mask")
|
||||
if mask is None:
|
||||
continue
|
||||
prediction_mask = self._extract_mask(preds)
|
||||
for metric_name, metric_fn in self.metrics.items():
|
||||
for threshold in self.config.thresholds:
|
||||
value = metric_fn(prediction_mask, mask, threshold)
|
||||
aggregated.setdefault(f"{metric_name}@{threshold}", []).append(value)
|
||||
if self.config.save_predictions:
|
||||
self._write_prediction(output_dir, idx, prediction_mask, sample["metadata"])
|
||||
summary = {
|
||||
"metrics": {k: float(np.mean(v)) if v else 0.0 for k, v in aggregated.items()},
|
||||
"config": self.config.__dict__,
|
||||
"num_samples": total,
|
||||
}
|
||||
with (output_dir / "evaluation_summary.json").open("w", encoding="utf-8") as handle:
|
||||
json.dump(summary, handle, indent=2)
|
||||
return summary
|
||||
|
||||
def _build_pipeline_inputs(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
||||
inputs: Dict[str, Any] = {"images": sample["pixel_values"]}
|
||||
prompts = sample.get("prompts", {})
|
||||
if "boxes" in prompts and prompts["boxes"]:
|
||||
inputs["boxes"] = prompts["boxes"]
|
||||
if "points" in prompts and prompts["points"]:
|
||||
inputs["points"] = prompts["points"]
|
||||
if "point_labels" in prompts and prompts["point_labels"]:
|
||||
inputs["point_labels"] = prompts["point_labels"]
|
||||
return inputs
|
||||
|
||||
def _extract_mask(self, pipeline_output: Any) -> np.ndarray:
|
||||
"""
|
||||
Normalize pipeline outputs into numpy masks.
|
||||
"""
|
||||
return extract_mask_from_pipeline_output(pipeline_output)
|
||||
|
||||
def _write_prediction(
|
||||
self,
|
||||
output_dir: Path,
|
||||
index: int,
|
||||
mask: np.ndarray,
|
||||
metadata: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
if metadata and "image_name" in metadata:
|
||||
filename = metadata["image_name"].replace(".jpg", "_pred.npy")
|
||||
else:
|
||||
filename = f"sample_{index:04d}_pred.npy"
|
||||
target_path = output_dir / "predictions"
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
np.save(target_path / filename, mask)
|
||||
@ -1,25 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import csv
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Iterable
|
||||
|
||||
|
||||
def write_json(summary: Dict, output_path: Path) -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with output_path.open("w", encoding="utf-8") as handle:
|
||||
json.dump(summary, handle, indent=2)
|
||||
|
||||
|
||||
def write_csv(rows: Iterable[Dict], output_path: Path) -> None:
|
||||
rows = list(rows)
|
||||
if not rows:
|
||||
return
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
fieldnames = sorted(rows[0].keys())
|
||||
with output_path.open("w", encoding="utf-8", newline="") as handle:
|
||||
writer = csv.DictWriter(handle, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for row in rows:
|
||||
writer.writerow(row)
|
||||
@ -1,55 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Optional
|
||||
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from ..dataset import DatasetRegistry
|
||||
from ..model import ModelRegistry
|
||||
from ..model_configuration import ConfigRegistry, EvaluationConfig
|
||||
from .pipeline_eval import PipelineEvaluator
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineCLIArguments:
|
||||
config_name: str = "sam2_bbox_prompt"
|
||||
model_key: str = "sam2"
|
||||
split: str = "test"
|
||||
split_file: Optional[str] = None
|
||||
device: Optional[str] = None
|
||||
max_samples: Optional[int] = None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = HfArgumentParser(PipelineCLIArguments)
|
||||
(cli_args,) = parser.parse_args_into_dataclasses()
|
||||
project_config = ConfigRegistry.get(cli_args.config_name)
|
||||
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
|
||||
dataset = DatasetRegistry.create(
|
||||
dataset_cfg.name,
|
||||
config=dataset_cfg,
|
||||
return_hf_dict=True,
|
||||
)
|
||||
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
|
||||
evaluation_config = replace(
|
||||
project_config.evaluation,
|
||||
max_samples=cli_args.max_samples,
|
||||
)
|
||||
if cli_args.device:
|
||||
adapter.build_pipeline(device=cli_args.device)
|
||||
evaluator = PipelineEvaluator(
|
||||
dataset=dataset,
|
||||
adapter=adapter,
|
||||
config=evaluation_config,
|
||||
)
|
||||
summary = evaluator.run()
|
||||
LOGGER.info("Evaluation summary: %s", summary)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
@ -1,16 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def extract_mask_from_pipeline_output(pipeline_output: Any) -> np.ndarray:
|
||||
if isinstance(pipeline_output, list):
|
||||
pipeline_output = pipeline_output[0]
|
||||
mask = pipeline_output.get("mask")
|
||||
if mask is None:
|
||||
raise ValueError("Pipeline output missing 'mask'.")
|
||||
if isinstance(mask, np.ndarray):
|
||||
return mask
|
||||
return np.array(mask)
|
||||
@ -1,7 +0,0 @@
|
||||
"""
|
||||
Backward-compatible wrapper that re-exports the predictor relocated to src.model.
|
||||
"""
|
||||
|
||||
from .model.predictor import HFSam2Predictor, build_hf_sam2_predictor
|
||||
|
||||
__all__ = ["HFSam2Predictor", "build_hf_sam2_predictor"]
|
||||
@ -1,17 +0,0 @@
|
||||
from .base import BaseModelAdapter
|
||||
from .inference import predict_with_bbox_prompt
|
||||
from .predictor import HFSam2Predictor, build_hf_sam2_predictor
|
||||
from .registry import ModelRegistry
|
||||
from .sam2_adapter import Sam2ModelAdapter
|
||||
from .trainer import FineTuningTrainer, TrainerArtifacts
|
||||
|
||||
__all__ = [
|
||||
"BaseModelAdapter",
|
||||
"FineTuningTrainer",
|
||||
"HFSam2Predictor",
|
||||
"ModelRegistry",
|
||||
"Sam2ModelAdapter",
|
||||
"TrainerArtifacts",
|
||||
"build_hf_sam2_predictor",
|
||||
"predict_with_bbox_prompt",
|
||||
]
|
||||
@ -1,66 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from transformers import pipeline
|
||||
|
||||
from ..model_configuration import ModelConfig
|
||||
|
||||
|
||||
class BaseModelAdapter(abc.ABC):
|
||||
"""
|
||||
Thin wrapper that standardizes how we instantiate models/processors/pipelines.
|
||||
"""
|
||||
|
||||
task: str = "image-segmentation"
|
||||
|
||||
def __init__(self, config: ModelConfig) -> None:
|
||||
self.config = config
|
||||
self._model = None
|
||||
self._processor = None
|
||||
self._pipeline = None
|
||||
|
||||
def load_pretrained(self):
|
||||
if self._model is None or self._processor is None:
|
||||
self._model, self._processor = self._load_pretrained()
|
||||
return self._model, self._processor
|
||||
|
||||
def build_pipeline(
|
||||
self,
|
||||
device: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if self._pipeline is None:
|
||||
model, processor = self.load_pretrained()
|
||||
pipe_kwargs = {
|
||||
"task": self.task,
|
||||
"model": model,
|
||||
"image_processor": processor,
|
||||
**self.config.pipeline_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
if device is not None:
|
||||
pipe_kwargs["device"] = device
|
||||
self._pipeline = self._create_pipeline(pipe_kwargs)
|
||||
return self._pipeline
|
||||
|
||||
async def build_pipeline_async(self, **kwargs):
|
||||
"""
|
||||
Async helper for future multi-device orchestration.
|
||||
"""
|
||||
return self.build_pipeline(**kwargs)
|
||||
|
||||
def save_pretrained(self, output_dir: str) -> None:
|
||||
model, processor = self.load_pretrained()
|
||||
model.save_pretrained(output_dir)
|
||||
processor.save_pretrained(output_dir)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _load_pretrained(self):
|
||||
"""
|
||||
Return (model, processor) tuple.
|
||||
"""
|
||||
|
||||
def _create_pipeline(self, pipe_kwargs: Dict[str, Any]):
|
||||
return pipeline(**pipe_kwargs)
|
||||
@ -1,32 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .predictor import HFSam2Predictor
|
||||
|
||||
|
||||
def predict_with_bbox_prompt(
|
||||
predictor: HFSam2Predictor,
|
||||
image: np.ndarray,
|
||||
bboxes: List[np.ndarray],
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Run SAM2 predictions for each bounding box and merge the masks.
|
||||
"""
|
||||
predictor.set_image(image)
|
||||
if not bboxes:
|
||||
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||
combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||
for bbox in bboxes:
|
||||
masks, _, _ = predictor.predict(
|
||||
point_coords=None,
|
||||
point_labels=None,
|
||||
box=bbox,
|
||||
multimask_output=False,
|
||||
)
|
||||
mask = masks[0]
|
||||
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
|
||||
combined_mask = combined_mask * 255
|
||||
return combined_mask
|
||||
@ -1,158 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import SamModel, SamProcessor
|
||||
|
||||
|
||||
class HFSam2Predictor:
|
||||
"""
|
||||
Predictor wrapper around Hugging Face SAM2 models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "facebook/sam2-hiera-small",
|
||||
device: Optional[str] = None,
|
||||
dtype: torch.dtype = torch.bfloat16,
|
||||
) -> None:
|
||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.dtype = dtype
|
||||
self.model = SamModel.from_pretrained(model_id).to(self.device)
|
||||
self.processor = SamProcessor.from_pretrained("./configs/preprocesser.json")
|
||||
self._override_processor_config()
|
||||
if dtype == torch.bfloat16:
|
||||
self.model = self.model.to(dtype=dtype)
|
||||
self.model.eval()
|
||||
self.current_image = None
|
||||
self.current_image_embeddings = None
|
||||
|
||||
def set_image(self, image: np.ndarray) -> None:
|
||||
if isinstance(image, np.ndarray):
|
||||
pil_image = Image.fromarray(image.astype(np.uint8))
|
||||
else:
|
||||
pil_image = image
|
||||
self.current_image = pil_image
|
||||
with torch.inference_mode():
|
||||
inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device)
|
||||
if self.dtype == torch.bfloat16:
|
||||
inputs = {
|
||||
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
|
||||
for k, v in inputs.items()
|
||||
}
|
||||
self.current_image_embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
|
||||
|
||||
def predict(
|
||||
self,
|
||||
point_coords: Optional[np.ndarray] = None,
|
||||
point_labels: Optional[np.ndarray] = None,
|
||||
box: Optional[np.ndarray] = None,
|
||||
multimask_output: bool = False,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
if self.current_image is None:
|
||||
raise ValueError("No image set. Call set_image() first.")
|
||||
input_points = self._prepare_points(point_coords)
|
||||
input_labels = self._prepare_labels(point_labels)
|
||||
input_boxes = self._prepare_boxes(box)
|
||||
with torch.inference_mode():
|
||||
inputs = self.processor(
|
||||
images=self.current_image,
|
||||
input_points=input_points,
|
||||
input_labels=input_labels,
|
||||
input_boxes=input_boxes,
|
||||
return_tensors="pt",
|
||||
).to(self.device)
|
||||
if self.dtype == torch.bfloat16:
|
||||
inputs = {
|
||||
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
|
||||
for k, v in inputs.items()
|
||||
}
|
||||
inputs.pop("pixel_values", None)
|
||||
inputs["image_embeddings"] = self.current_image_embeddings
|
||||
outputs = self.model(**inputs, multimask_output=multimask_output)
|
||||
masks = self.processor.image_processor.post_process_masks(
|
||||
outputs.pred_masks.float().cpu(),
|
||||
inputs["original_sizes"].cpu(),
|
||||
inputs["reshaped_input_sizes"].cpu(),
|
||||
)[0]
|
||||
scores = outputs.iou_scores.float().cpu().numpy()[0]
|
||||
masks_np = (masks.squeeze(1).numpy() > 0).astype(np.uint8)
|
||||
logits = outputs.pred_masks.float().cpu().numpy()[0]
|
||||
return masks_np, scores, logits
|
||||
|
||||
def _prepare_points(self, coords: Optional[np.ndarray]):
|
||||
"""
|
||||
Points must be shaped (num_points, 2); wrap in outer batch dimension.
|
||||
"""
|
||||
if coords is None:
|
||||
return None
|
||||
coords_arr = np.asarray(coords)
|
||||
if coords_arr.ndim == 1:
|
||||
coords_arr = coords_arr[None, :]
|
||||
if coords_arr.ndim != 2:
|
||||
raise ValueError(f"Point coords must be 2-D, got {coords_arr.shape}.")
|
||||
return [coords_arr.tolist()]
|
||||
|
||||
def _prepare_labels(self, labels: Optional[np.ndarray]):
|
||||
"""
|
||||
Labels mirror the point dimension and are shaped (num_points,).
|
||||
"""
|
||||
if labels is None:
|
||||
return None
|
||||
labels_arr = np.asarray(labels)
|
||||
if labels_arr.ndim == 0:
|
||||
labels_arr = labels_arr[None]
|
||||
if labels_arr.ndim != 1:
|
||||
raise ValueError(f"Point labels must be 1-D, got {labels_arr.shape}.")
|
||||
return [labels_arr.tolist()]
|
||||
|
||||
def _prepare_boxes(self, boxes: Optional[np.ndarray]):
|
||||
"""
|
||||
HF expects boxes in shape (batch, num_boxes, 4); accept (4,), (N,4), or (B,N,4).
|
||||
"""
|
||||
if boxes is None:
|
||||
return None
|
||||
boxes_arr = np.asarray(boxes)
|
||||
if boxes_arr.ndim == 1:
|
||||
return [[boxes_arr.tolist()]]
|
||||
if boxes_arr.ndim == 2:
|
||||
return [boxes_arr.tolist()]
|
||||
if boxes_arr.ndim == 3:
|
||||
return boxes_arr.tolist()
|
||||
raise ValueError(f"Boxes should be 1/2/3-D, got {boxes_arr.shape}.")
|
||||
|
||||
def _override_processor_config(self) -> None:
|
||||
"""
|
||||
Override processor config with local settings to avoid upstream regressions.
|
||||
"""
|
||||
config_path = Path(__file__).resolve().parents[2] / "configs" / "preprocesser.json"
|
||||
if not config_path.exists():
|
||||
return
|
||||
try:
|
||||
config_dict = json.loads(config_path.read_text())
|
||||
except Exception:
|
||||
return
|
||||
image_processor = getattr(self.processor, "image_processor", None)
|
||||
if image_processor is None or not hasattr(image_processor, "config"):
|
||||
return
|
||||
# config behaves like a dict; update in-place.
|
||||
try:
|
||||
image_processor.config.update(config_dict)
|
||||
except Exception:
|
||||
for key, value in config_dict.items():
|
||||
try:
|
||||
setattr(image_processor.config, key, value)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
def build_hf_sam2_predictor(
|
||||
model_id: str = "facebook/sam2-hiera-small",
|
||||
device: Optional[str] = None,
|
||||
) -> HFSam2Predictor:
|
||||
return HFSam2Predictor(model_id=model_id, device=device)
|
||||
@ -1,33 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from ..model_configuration import ModelConfig
|
||||
from .base import BaseModelAdapter
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
"""
|
||||
Maps model keys to adapter classes so configs can reference them declaratively.
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[BaseModelAdapter]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str):
|
||||
def decorator(adapter_cls: Type[BaseModelAdapter]) -> Type[BaseModelAdapter]:
|
||||
cls._registry[name] = adapter_cls
|
||||
return adapter_cls
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def create(cls, name: str, config: ModelConfig) -> BaseModelAdapter:
|
||||
if name not in cls._registry:
|
||||
raise KeyError(f"ModelAdapter '{name}' is not registered.")
|
||||
adapter_cls = cls._registry[name]
|
||||
return adapter_cls(config)
|
||||
|
||||
@classmethod
|
||||
def available(cls) -> Dict[str, Type[BaseModelAdapter]]:
|
||||
return dict(cls._registry)
|
||||
@ -1,35 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
from transformers import AutoModelForImageSegmentation, AutoProcessor
|
||||
|
||||
from ..model_configuration import ModelConfig
|
||||
from .base import BaseModelAdapter
|
||||
from .registry import ModelRegistry
|
||||
|
||||
|
||||
@ModelRegistry.register("sam2")
|
||||
class Sam2ModelAdapter(BaseModelAdapter):
|
||||
"""
|
||||
Adapter that exposes SAM2 checkpoints through the HF pipeline interface.
|
||||
"""
|
||||
|
||||
def __init__(self, config: ModelConfig) -> None:
|
||||
super().__init__(config)
|
||||
self.task = "image-segmentation"
|
||||
|
||||
def _load_pretrained(self) -> Tuple[Any, Any]:
|
||||
model = AutoModelForImageSegmentation.from_pretrained(
|
||||
self.config.name_or_path,
|
||||
revision=self.config.revision,
|
||||
cache_dir=self.config.cache_dir,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
self.config.name_or_path,
|
||||
revision=self.config.revision,
|
||||
cache_dir=self.config.cache_dir,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
return model, processor
|
||||
@ -1,88 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Optional
|
||||
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from ..dataset import DatasetRegistry
|
||||
from ..model_configuration import ConfigRegistry, DatasetConfig
|
||||
from .registry import ModelRegistry
|
||||
from .trainer import FineTuningTrainer
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainCLIArguments:
|
||||
config_name: str = "sam2_bbox_prompt"
|
||||
model_key: str = "sam2"
|
||||
train_split: str = "train"
|
||||
eval_split: str = "val"
|
||||
train_split_file: Optional[str] = None
|
||||
eval_split_file: Optional[str] = None
|
||||
skip_eval: bool = False
|
||||
device: Optional[str] = None
|
||||
|
||||
|
||||
def build_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
|
||||
overrides = {}
|
||||
if split:
|
||||
overrides["split"] = split
|
||||
if split_file:
|
||||
overrides["split_file"] = split_file
|
||||
return replace(config, **overrides)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = HfArgumentParser(TrainCLIArguments)
|
||||
(cli_args,) = parser.parse_args_into_dataclasses()
|
||||
project_config = ConfigRegistry.get(cli_args.config_name)
|
||||
train_dataset_cfg = build_dataset(
|
||||
project_config.dataset, cli_args.train_split, cli_args.train_split_file
|
||||
)
|
||||
eval_dataset_cfg = (
|
||||
build_dataset(project_config.dataset, cli_args.eval_split, cli_args.eval_split_file)
|
||||
if not cli_args.skip_eval
|
||||
else None
|
||||
)
|
||||
|
||||
train_dataset = DatasetRegistry.create(
|
||||
train_dataset_cfg.name,
|
||||
config=train_dataset_cfg,
|
||||
return_hf_dict=True,
|
||||
)
|
||||
eval_dataset = (
|
||||
DatasetRegistry.create(
|
||||
eval_dataset_cfg.name,
|
||||
config=eval_dataset_cfg,
|
||||
return_hf_dict=True,
|
||||
)
|
||||
if eval_dataset_cfg
|
||||
else None
|
||||
)
|
||||
|
||||
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
|
||||
if cli_args.device:
|
||||
adapter.build_pipeline(device=cli_args.device)
|
||||
|
||||
trainer_builder = FineTuningTrainer(
|
||||
adapter=adapter,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
training_config=project_config.training,
|
||||
)
|
||||
artifacts = trainer_builder.build()
|
||||
LOGGER.info("Starting training with args: %s", artifacts.training_args)
|
||||
train_result = artifacts.trainer.train()
|
||||
LOGGER.info("Training finished: %s", train_result)
|
||||
artifacts.trainer.save_model(project_config.training.output_dir)
|
||||
if eval_dataset and not cli_args.skip_eval:
|
||||
metrics = artifacts.trainer.evaluate()
|
||||
LOGGER.info("Evaluation metrics: %s", metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
@ -1,64 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from transformers import Trainer, TrainingArguments
|
||||
|
||||
from ..dataset import BaseDataset, collate_samples
|
||||
from ..model_configuration import TrainingConfig
|
||||
from .base import BaseModelAdapter
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainerArtifacts:
|
||||
trainer: Trainer
|
||||
training_args: TrainingArguments
|
||||
|
||||
|
||||
class FineTuningTrainer:
|
||||
"""
|
||||
Helper that bridges TrainingConfig + datasets + adapters into HF Trainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
adapter: BaseModelAdapter,
|
||||
train_dataset: Optional[BaseDataset],
|
||||
eval_dataset: Optional[BaseDataset],
|
||||
training_config: TrainingConfig,
|
||||
trainer_kwargs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
self.adapter = adapter
|
||||
self.train_dataset = train_dataset
|
||||
self.eval_dataset = eval_dataset
|
||||
self.training_config = training_config
|
||||
self.trainer_kwargs = trainer_kwargs or {}
|
||||
|
||||
def build(self) -> TrainerArtifacts:
|
||||
model, processor = self.adapter.load_pretrained()
|
||||
training_args = TrainingArguments(
|
||||
output_dir=self.training_config.output_dir,
|
||||
num_train_epochs=self.training_config.num_train_epochs,
|
||||
per_device_train_batch_size=self.training_config.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=self.training_config.per_device_eval_batch_size,
|
||||
learning_rate=self.training_config.learning_rate,
|
||||
gradient_accumulation_steps=self.training_config.gradient_accumulation_steps,
|
||||
lr_scheduler_type=self.training_config.lr_scheduler_type,
|
||||
warmup_ratio=self.training_config.warmup_ratio,
|
||||
weight_decay=self.training_config.weight_decay,
|
||||
seed=self.training_config.seed,
|
||||
fp16=self.training_config.fp16,
|
||||
bf16=self.training_config.bf16,
|
||||
report_to=self.training_config.report_to,
|
||||
)
|
||||
hf_trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
data_collator=collate_samples,
|
||||
tokenizer=processor,
|
||||
**self.trainer_kwargs,
|
||||
)
|
||||
return TrainerArtifacts(trainer=hf_trainer, training_args=training_args)
|
||||
@ -1,22 +0,0 @@
|
||||
from .config import (
|
||||
DatasetConfig,
|
||||
EvaluationConfig,
|
||||
ModelConfig,
|
||||
ProjectConfig,
|
||||
TrainingConfig,
|
||||
VisualizationConfig,
|
||||
)
|
||||
from .registry import ConfigRegistry
|
||||
|
||||
# ensure example configs register themselves
|
||||
from . import sam2_bbox # noqa: F401
|
||||
|
||||
__all__ = [
|
||||
"DatasetConfig",
|
||||
"EvaluationConfig",
|
||||
"ModelConfig",
|
||||
"ProjectConfig",
|
||||
"TrainingConfig",
|
||||
"VisualizationConfig",
|
||||
"ConfigRegistry",
|
||||
]
|
||||
@ -1,89 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
def _default_dict() -> Dict[str, Any]:
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
name: str
|
||||
data_root: str
|
||||
split: str = "test"
|
||||
split_file: Optional[str] = None
|
||||
annotation_file: Optional[str] = None
|
||||
image_folder: Optional[str] = None
|
||||
mask_folder: Optional[str] = None
|
||||
extra_params: Dict[str, Any] = field(default_factory=_default_dict)
|
||||
|
||||
def resolve_path(self, relative: Optional[str]) -> Optional[Path]:
|
||||
if relative is None:
|
||||
return None
|
||||
return Path(self.data_root) / relative
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
name_or_path: str
|
||||
revision: Optional[str] = None
|
||||
config_name: Optional[str] = None
|
||||
cache_dir: Optional[str] = None
|
||||
prompt_type: str = "bbox"
|
||||
image_size: Optional[int] = None
|
||||
pipeline_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
|
||||
adapter_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingConfig:
|
||||
output_dir: str = "./outputs"
|
||||
num_train_epochs: float = 3.0
|
||||
per_device_train_batch_size: int = 1
|
||||
per_device_eval_batch_size: int = 1
|
||||
learning_rate: float = 1e-4
|
||||
weight_decay: float = 0.0
|
||||
gradient_accumulation_steps: int = 1
|
||||
lr_scheduler_type: str = "linear"
|
||||
warmup_ratio: float = 0.0
|
||||
seed: int = 42
|
||||
fp16: bool = False
|
||||
bf16: bool = False
|
||||
report_to: List[str] = field(default_factory=lambda: ["tensorboard"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationConfig:
|
||||
output_dir: str = "./results"
|
||||
metrics: List[str] = field(default_factory=lambda: ["iou", "dice", "precision", "recall"])
|
||||
thresholds: List[float] = field(default_factory=lambda: [0.5])
|
||||
max_samples: Optional[int] = None
|
||||
save_predictions: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisualizationConfig:
|
||||
num_samples: int = 20
|
||||
overlay_alpha: float = 0.6
|
||||
save_dir: str = "./results/visualizations"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProjectConfig:
|
||||
dataset: DatasetConfig
|
||||
model: ModelConfig
|
||||
training: TrainingConfig = field(default_factory=TrainingConfig)
|
||||
evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
|
||||
visualization: VisualizationConfig = field(default_factory=VisualizationConfig)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"dataset": self.dataset,
|
||||
"model": self.model,
|
||||
"training": self.training,
|
||||
"evaluation": self.evaluation,
|
||||
"visualization": self.visualization,
|
||||
}
|
||||
@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .config import ProjectConfig
|
||||
|
||||
|
||||
class ConfigRegistry:
|
||||
"""
|
||||
Stores reusable project configurations (dataset + model + training bundle).
|
||||
"""
|
||||
|
||||
_registry: Dict[str, ProjectConfig] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str, config: ProjectConfig) -> ProjectConfig:
|
||||
cls._registry[name] = config
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> ProjectConfig:
|
||||
if name not in cls._registry:
|
||||
raise KeyError(f"ProjectConfig '{name}' is not registered.")
|
||||
return cls._registry[name]
|
||||
|
||||
@classmethod
|
||||
def available(cls) -> Dict[str, ProjectConfig]:
|
||||
return dict(cls._registry)
|
||||
@ -1,47 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .config import (
|
||||
DatasetConfig,
|
||||
EvaluationConfig,
|
||||
ModelConfig,
|
||||
ProjectConfig,
|
||||
TrainingConfig,
|
||||
VisualizationConfig,
|
||||
)
|
||||
from .registry import ConfigRegistry
|
||||
|
||||
|
||||
SAM2_BBOX_CONFIG = ProjectConfig(
|
||||
dataset=DatasetConfig(
|
||||
name="crack500",
|
||||
data_root="./crack500",
|
||||
split="test",
|
||||
split_file="test.txt",
|
||||
image_folder="testcrop",
|
||||
mask_folder="testdata",
|
||||
),
|
||||
model=ModelConfig(
|
||||
name_or_path="facebook/sam2.1-hiera-small",
|
||||
prompt_type="bbox",
|
||||
pipeline_kwargs={"batch_size": 1},
|
||||
),
|
||||
training=TrainingConfig(
|
||||
output_dir="./outputs/sam2_bbox",
|
||||
num_train_epochs=5,
|
||||
per_device_train_batch_size=1,
|
||||
per_device_eval_batch_size=1,
|
||||
learning_rate=1e-4,
|
||||
gradient_accumulation_steps=4,
|
||||
lr_scheduler_type="cosine",
|
||||
),
|
||||
evaluation=EvaluationConfig(
|
||||
output_dir="./results/bbox_prompt",
|
||||
thresholds=[0.3, 0.5, 0.75],
|
||||
),
|
||||
visualization=VisualizationConfig(
|
||||
save_dir="./results/bbox_prompt/visualizations",
|
||||
num_samples=20,
|
||||
),
|
||||
)
|
||||
|
||||
ConfigRegistry.register("sam2_bbox_prompt", SAM2_BBOX_CONFIG)
|
||||
@ -1,5 +1,5 @@
|
||||
"""
|
||||
点提示方式的 SAM2 裂缝分割实现(使用 HuggingFace Transformers)
|
||||
点提示方式的 SAM2 裂缝分割实现
|
||||
使用骨架采样策略,支持 1, 3, 5 个点
|
||||
"""
|
||||
|
||||
@ -13,102 +13,103 @@ from tqdm import tqdm
|
||||
import json
|
||||
from skimage.morphology import skeletonize
|
||||
|
||||
from .hf_sam2_predictor import HFSam2Predictor
|
||||
from sam2.build_sam import build_sam2
|
||||
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||
|
||||
|
||||
def sample_points_on_skeleton(mask: np.ndarray, num_points: int = 5) -> np.ndarray:
|
||||
"""
|
||||
在骨架上均匀采样点
|
||||
|
||||
|
||||
Args:
|
||||
mask: 二值掩码 (H, W),值为 0 或 255
|
||||
num_points: 采样点数量
|
||||
|
||||
|
||||
Returns:
|
||||
points: 采样点坐标 (N, 2),格式为 [x, y]
|
||||
"""
|
||||
# 确保掩码是二值的
|
||||
binary_mask = (mask > 0).astype(bool)
|
||||
|
||||
|
||||
# 骨架化
|
||||
try:
|
||||
skeleton = skeletonize(binary_mask)
|
||||
except:
|
||||
# 如果骨架化失败,直接使用掩码
|
||||
skeleton = binary_mask
|
||||
|
||||
|
||||
# 获取骨架点坐标 (y, x)
|
||||
skeleton_coords = np.argwhere(skeleton)
|
||||
|
||||
|
||||
if len(skeleton_coords) == 0:
|
||||
# 如果没有骨架点,返回空数组
|
||||
return np.array([]).reshape(0, 2)
|
||||
|
||||
|
||||
if len(skeleton_coords) <= num_points:
|
||||
# 如果骨架点数少于需要的点数,返回所有点
|
||||
# 转换为 (x, y) 格式
|
||||
return skeleton_coords[:, [1, 0]]
|
||||
|
||||
|
||||
# 均匀间隔采样
|
||||
indices = np.linspace(0, len(skeleton_coords) - 1, num_points, dtype=int)
|
||||
sampled_coords = skeleton_coords[indices]
|
||||
|
||||
|
||||
# 转换为 (x, y) 格式
|
||||
points = sampled_coords[:, [1, 0]]
|
||||
|
||||
|
||||
return points
|
||||
|
||||
|
||||
def sample_points_per_component(
|
||||
mask: np.ndarray,
|
||||
mask: np.ndarray,
|
||||
num_points_per_component: int = 3
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
为每个连通域独立采样点
|
||||
|
||||
|
||||
Args:
|
||||
mask: 二值掩码 (H, W)
|
||||
num_points_per_component: 每个连通域的点数
|
||||
|
||||
|
||||
Returns:
|
||||
points: 所有采样点 (N, 2)
|
||||
labels: 点标签,全为 1(正样本)
|
||||
"""
|
||||
# 连通域分析
|
||||
num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
|
||||
|
||||
|
||||
all_points = []
|
||||
|
||||
|
||||
# 跳过背景 (label 0)
|
||||
for region_id in range(1, num_labels):
|
||||
region_mask = (labels_map == region_id).astype(np.uint8) * 255
|
||||
|
||||
|
||||
# 对每个连通域采样
|
||||
points = sample_points_on_skeleton(region_mask, num_points_per_component)
|
||||
|
||||
|
||||
if len(points) > 0:
|
||||
all_points.append(points)
|
||||
|
||||
|
||||
if len(all_points) == 0:
|
||||
return np.array([]).reshape(0, 2), np.array([])
|
||||
|
||||
|
||||
# 合并所有点
|
||||
all_points = np.vstack(all_points)
|
||||
|
||||
|
||||
# 所有点都是正样本
|
||||
point_labels = np.ones(len(all_points), dtype=np.int32)
|
||||
|
||||
|
||||
return all_points, point_labels
|
||||
|
||||
|
||||
def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
加载图像和掩码
|
||||
|
||||
|
||||
Args:
|
||||
image_path: 图像路径
|
||||
mask_path: 掩码路径
|
||||
|
||||
|
||||
Returns:
|
||||
image: RGB 图像 (H, W, 3)
|
||||
mask: 二值掩码 (H, W)
|
||||
@ -118,79 +119,79 @@ def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np
|
||||
if image is None:
|
||||
raise ValueError(f"无法加载图像: {image_path}")
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
# 加载掩码
|
||||
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
||||
if mask is None:
|
||||
raise ValueError(f"无法加载掩码: {mask_path}")
|
||||
|
||||
|
||||
return image, mask
|
||||
|
||||
|
||||
def predict_with_point_prompt(
|
||||
predictor: HFSam2Predictor,
|
||||
predictor: SAM2ImagePredictor,
|
||||
image: np.ndarray,
|
||||
points: np.ndarray,
|
||||
point_labels: np.ndarray = None
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
使用点提示进行 SAM2 预测
|
||||
|
||||
|
||||
Args:
|
||||
predictor: HFSam2Predictor 实例
|
||||
predictor: SAM2ImagePredictor 实例
|
||||
image: RGB 图像 (H, W, 3)
|
||||
points: 点坐标 (N, 2),格式为 [x, y]
|
||||
point_labels: 点标签 (N,),1 表示正样本,0 表示负样本
|
||||
|
||||
|
||||
Returns:
|
||||
mask_pred: 预测掩码 (H, W)
|
||||
"""
|
||||
# 设置图像
|
||||
predictor.set_image(image)
|
||||
|
||||
|
||||
# 如果没有点,返回空掩码
|
||||
if len(points) == 0:
|
||||
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||
|
||||
|
||||
# 默认所有点都是正样本
|
||||
if point_labels is None:
|
||||
point_labels = np.ones(len(points), dtype=np.int32)
|
||||
|
||||
|
||||
# 使用点提示预测
|
||||
masks, scores, logits = predictor.predict(
|
||||
point_coords=points,
|
||||
point_labels=point_labels,
|
||||
multimask_output=False,
|
||||
)
|
||||
|
||||
|
||||
# 取第一个掩码(因为 multimask_output=False)
|
||||
mask_pred = masks[0] # shape: (H, W)
|
||||
|
||||
|
||||
# 转换为 0-255
|
||||
mask_pred = (mask_pred * 255).astype(np.uint8)
|
||||
|
||||
|
||||
return mask_pred
|
||||
|
||||
|
||||
def process_test_set(
|
||||
data_root: str,
|
||||
test_file: str,
|
||||
predictor: HFSam2Predictor,
|
||||
predictor: SAM2ImagePredictor,
|
||||
output_dir: str,
|
||||
num_points: int = 5,
|
||||
per_component: bool = False
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
处理整个测试集
|
||||
|
||||
|
||||
Args:
|
||||
data_root: 数据集根目录
|
||||
test_file: 测试集文件路径 (test.txt)
|
||||
predictor: HFSam2Predictor 实例
|
||||
predictor: SAM2ImagePredictor 实例
|
||||
output_dir: 输出目录
|
||||
num_points: 采样点数量
|
||||
per_component: 是否为每个连通域独立采样
|
||||
|
||||
|
||||
Returns:
|
||||
results: 包含每个样本信息的列表
|
||||
"""
|
||||
@ -198,27 +199,27 @@ def process_test_set(
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
pred_dir = os.path.join(output_dir, "predictions")
|
||||
os.makedirs(pred_dir, exist_ok=True)
|
||||
|
||||
|
||||
# 读取测试集文件
|
||||
with open(test_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
|
||||
results = []
|
||||
|
||||
|
||||
print(f"开始处理 {len(lines)} 张测试图像...")
|
||||
print(f"采样策略: {'每连通域' if per_component else '全局'} {num_points} 个点")
|
||||
|
||||
|
||||
for line in tqdm(lines, desc="处理测试集"):
|
||||
parts = line.strip().split()
|
||||
if len(parts) != 2:
|
||||
continue
|
||||
|
||||
|
||||
img_rel_path, mask_rel_path = parts
|
||||
|
||||
|
||||
# 构建完整路径
|
||||
img_path = os.path.join(data_root, img_rel_path)
|
||||
mask_path = os.path.join(data_root, mask_rel_path)
|
||||
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(img_path):
|
||||
print(f"警告: 图像不存在 {img_path}")
|
||||
@ -226,11 +227,11 @@ def process_test_set(
|
||||
if not os.path.exists(mask_path):
|
||||
print(f"警告: 掩码不存在 {mask_path}")
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
# 加载图像和掩码
|
||||
image, mask_gt = load_image_and_mask(img_path, mask_path)
|
||||
|
||||
|
||||
# 从 GT 掩码采样点
|
||||
if per_component:
|
||||
points, point_labels = sample_points_per_component(
|
||||
@ -239,18 +240,18 @@ def process_test_set(
|
||||
else:
|
||||
points = sample_points_on_skeleton(mask_gt, num_points=num_points)
|
||||
point_labels = np.ones(len(points), dtype=np.int32)
|
||||
|
||||
|
||||
# 使用 SAM2 预测
|
||||
with torch.inference_mode():
|
||||
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||
mask_pred = predict_with_point_prompt(
|
||||
predictor, image, points, point_labels
|
||||
)
|
||||
|
||||
|
||||
# 保存预测掩码
|
||||
img_name = Path(img_rel_path).stem
|
||||
pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
|
||||
cv2.imwrite(pred_path, mask_pred)
|
||||
|
||||
|
||||
# 记录结果
|
||||
results.append({
|
||||
"image_path": img_rel_path,
|
||||
@ -259,60 +260,62 @@ def process_test_set(
|
||||
"num_points": len(points),
|
||||
"image_shape": image.shape[:2],
|
||||
})
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理失败 {img_path}: {str(e)}")
|
||||
continue
|
||||
|
||||
|
||||
# 保存结果信息
|
||||
results_file = os.path.join(output_dir, "results_info.json")
|
||||
with open(results_file, 'w') as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
|
||||
print(f"\n处理完成!共处理 {len(results)} 张图像")
|
||||
print(f"预测掩码保存在: {pred_dir}")
|
||||
print(f"结果信息保存在: {results_file}")
|
||||
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="SAM2 点提示方式 (HuggingFace) - Crack500 数据集评估")
|
||||
|
||||
parser = argparse.ArgumentParser(description="SAM2 点提示方式 - Crack500 数据集评估")
|
||||
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
||||
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件")
|
||||
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace 模型 ID")
|
||||
parser.add_argument("--output_dir", type=str, default="./results/point_prompt_hf", help="输出目录")
|
||||
parser.add_argument("--checkpoint", type=str, default="./sam2/checkpoints/sam2.1_hiera_small.pt", help="模型检查点")
|
||||
parser.add_argument("--model_cfg", type=str, default="sam2.1_hiera_s.yaml", help="模型配置")
|
||||
parser.add_argument("--output_dir", type=str, default="./results/point_prompt", help="输出目录")
|
||||
parser.add_argument("--num_points", type=int, default=5, choices=[1, 3, 5], help="采样点数量")
|
||||
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样")
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
print("=" * 60)
|
||||
print("SAM2 点提示方式 (HuggingFace) - Crack500 数据集评估")
|
||||
print("SAM2 点提示方式 - Crack500 数据集评估")
|
||||
print("=" * 60)
|
||||
print(f"数据集根目录: {args.data_root}")
|
||||
print(f"测试集文件: {args.test_file}")
|
||||
print(f"模型: {args.model_id}")
|
||||
print(f"模型检查点: {args.checkpoint}")
|
||||
print(f"模型配置: {args.model_cfg}")
|
||||
print(f"采样点数量: {args.num_points}")
|
||||
print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
|
||||
print(f"输出目录: {args.output_dir}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
# 检查 CUDA 是否可用
|
||||
if not torch.cuda.is_available():
|
||||
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||||
else:
|
||||
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||
|
||||
# 构建 SAM2 predictor
|
||||
|
||||
# 构建 SAM2 模型
|
||||
print("\n加载 SAM2 模型...")
|
||||
from .hf_sam2_predictor import build_hf_sam2_predictor
|
||||
predictor = build_hf_sam2_predictor(model_id=args.model_id)
|
||||
sam2_model = build_sam2(args.model_cfg, args.checkpoint)
|
||||
predictor = SAM2ImagePredictor(sam2_model)
|
||||
print("模型加载完成!")
|
||||
|
||||
|
||||
# 处理测试集
|
||||
results = process_test_set(
|
||||
data_root=args.data_root,
|
||||
@ -322,7 +325,7 @@ def main():
|
||||
num_points=args.num_points,
|
||||
per_component=args.per_component
|
||||
)
|
||||
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("处理完成!接下来请运行评估脚本计算指标。")
|
||||
print("=" * 60)
|
||||
|
||||
@ -1,8 +0,0 @@
|
||||
from .config import TaskConfig, TaskStepConfig
|
||||
from .pipeline import TaskRunner
|
||||
from .registry import TaskRegistry
|
||||
|
||||
# ensure built-in tasks are registered
|
||||
from . import examples # noqa: F401
|
||||
|
||||
__all__ = ["TaskConfig", "TaskRunner", "TaskRegistry", "TaskStepConfig"]
|
||||
@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
|
||||
TaskStepKind = Literal[
|
||||
"train",
|
||||
"evaluate",
|
||||
"visualize",
|
||||
"bbox_inference",
|
||||
"point_inference",
|
||||
"legacy_evaluation",
|
||||
"legacy_visualization",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskStepConfig:
|
||||
kind: TaskStepKind
|
||||
dataset_split: Optional[str] = None
|
||||
dataset_split_file: Optional[str] = None
|
||||
limit: Optional[int] = None
|
||||
eval_split: Optional[str] = None
|
||||
eval_split_file: Optional[str] = None
|
||||
params: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskConfig:
|
||||
name: str
|
||||
description: str
|
||||
project_config_name: str
|
||||
model_key: str = "sam2"
|
||||
steps: List[TaskStepConfig] = field(default_factory=list)
|
||||
dataset_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||
model_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||
training_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||
evaluation_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||
visualization_overrides: Dict[str, Any] = field(default_factory=dict)
|
||||
@ -1,34 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .config import TaskConfig, TaskStepConfig
|
||||
from .registry import TaskRegistry
|
||||
|
||||
TaskRegistry.register(
|
||||
TaskConfig(
|
||||
name="sam2_crack500_eval",
|
||||
description="Evaluate SAM2 bbox prompt checkpoints on Crack500 and render overlays.",
|
||||
project_config_name="sam2_bbox_prompt",
|
||||
steps=[
|
||||
TaskStepConfig(kind="evaluate", dataset_split="test"),
|
||||
TaskStepConfig(kind="visualize", dataset_split="test", limit=20),
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
TaskRegistry.register(
|
||||
TaskConfig(
|
||||
name="sam2_crack500_train_eval",
|
||||
description="Fine-tune SAM2 on Crack500 train split, evaluate on val, then visualize results.",
|
||||
project_config_name="sam2_bbox_prompt",
|
||||
steps=[
|
||||
TaskStepConfig(
|
||||
kind="train",
|
||||
dataset_split="train",
|
||||
eval_split="val",
|
||||
params={"num_train_epochs": 2},
|
||||
),
|
||||
TaskStepConfig(kind="evaluate", dataset_split="val", limit=32),
|
||||
TaskStepConfig(kind="visualize", dataset_split="val", limit=16),
|
||||
],
|
||||
)
|
||||
)
|
||||
@ -1,40 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import tomllib
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .config import TaskConfig, TaskStepConfig
|
||||
|
||||
|
||||
def load_task_from_toml(path: str | Path) -> TaskConfig:
|
||||
"""
|
||||
Load a TaskConfig from a TOML file.
|
||||
"""
|
||||
data = tomllib.loads(Path(path).read_text(encoding="utf-8"))
|
||||
task_data = data.get("task", {})
|
||||
steps_data: List[Dict[str, Any]] = data.get("steps", [])
|
||||
steps = [
|
||||
TaskStepConfig(
|
||||
kind=step["kind"],
|
||||
dataset_split=step.get("dataset_split"),
|
||||
dataset_split_file=step.get("dataset_split_file"),
|
||||
limit=step.get("limit"),
|
||||
eval_split=step.get("eval_split"),
|
||||
eval_split_file=step.get("eval_split_file"),
|
||||
params=step.get("params", {}),
|
||||
)
|
||||
for step in steps_data
|
||||
]
|
||||
return TaskConfig(
|
||||
name=task_data["name"],
|
||||
description=task_data.get("description", ""),
|
||||
project_config_name=task_data["project_config_name"],
|
||||
model_key=task_data.get("model_key", "sam2"),
|
||||
steps=steps,
|
||||
dataset_overrides=task_data.get("dataset_overrides", {}),
|
||||
model_overrides=task_data.get("model_overrides", {}),
|
||||
training_overrides=task_data.get("training_overrides", {}),
|
||||
evaluation_overrides=task_data.get("evaluation_overrides", {}),
|
||||
visualization_overrides=task_data.get("visualization_overrides", {}),
|
||||
)
|
||||
@ -1,264 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import fields, replace
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from ..bbox_prompt import process_test_set as bbox_process_test_set
|
||||
from ..dataset import DatasetRegistry
|
||||
from ..evaluation import PipelineEvaluator
|
||||
from ..evaluation.utils import extract_mask_from_pipeline_output
|
||||
from ..hf_sam2_predictor import build_hf_sam2_predictor
|
||||
from ..legacy_evaluation import evaluate_test_set as legacy_evaluate_test_set
|
||||
from ..legacy_visualization import (
|
||||
create_metrics_distribution_plot,
|
||||
visualize_test_set as legacy_visualize_test_set,
|
||||
)
|
||||
from ..model import FineTuningTrainer, ModelRegistry
|
||||
from ..model_configuration import ConfigRegistry, DatasetConfig, ProjectConfig
|
||||
from ..point_prompt import process_test_set as point_process_test_set
|
||||
from ..visualization import OverlayGenerator
|
||||
from .config import TaskConfig, TaskStepConfig
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _replace_dataclass(instance, updates: Dict[str, Any]):
|
||||
if not updates:
|
||||
return instance
|
||||
valid_fields = {f.name for f in fields(type(instance))}
|
||||
filtered = {k: v for k, v in updates.items() if k in valid_fields}
|
||||
if not filtered:
|
||||
return instance
|
||||
return replace(instance, **filtered)
|
||||
|
||||
|
||||
def _override_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
|
||||
updates: Dict[str, Any] = {"split": split}
|
||||
if split_file:
|
||||
updates["split_file"] = split_file
|
||||
return replace(config, **updates)
|
||||
|
||||
|
||||
class TaskRunner:
|
||||
"""
|
||||
Sequentially executes a series of task steps (train/eval/visualize).
|
||||
"""
|
||||
|
||||
def __init__(self, task_config: TaskConfig, project_config: Optional[ProjectConfig] = None) -> None:
|
||||
self.task_config = task_config
|
||||
base_project = project_config or ConfigRegistry.get(task_config.project_config_name)
|
||||
if project_config is None:
|
||||
base_project = self._apply_project_overrides(base_project)
|
||||
self.project_config = base_project
|
||||
self.adapter = ModelRegistry.create(task_config.model_key, self.project_config.model)
|
||||
|
||||
def run(self) -> None:
|
||||
LOGGER.info("Starting task '%s'", self.task_config.name)
|
||||
for idx, step in enumerate(self.task_config.steps, start=1):
|
||||
LOGGER.info("Running step %d/%d: %s", idx, len(self.task_config.steps), step.kind)
|
||||
if step.kind == "train":
|
||||
self._run_train(step)
|
||||
elif step.kind == "evaluate":
|
||||
self._run_evaluate(step)
|
||||
elif step.kind == "visualize":
|
||||
self._run_visualize(step)
|
||||
elif step.kind == "bbox_inference":
|
||||
self._run_bbox_inference(step)
|
||||
elif step.kind == "point_inference":
|
||||
self._run_point_inference(step)
|
||||
elif step.kind == "legacy_evaluation":
|
||||
self._run_legacy_evaluation(step)
|
||||
elif step.kind == "legacy_visualization":
|
||||
self._run_legacy_visualization(step)
|
||||
else:
|
||||
raise ValueError(f"Unknown task step: {step.kind}")
|
||||
|
||||
def _build_dataset(self, split: str, split_file: Optional[str]):
|
||||
dataset_cfg = _override_dataset(self.project_config.dataset, split, split_file)
|
||||
return DatasetRegistry.create(
|
||||
dataset_cfg.name,
|
||||
config=dataset_cfg,
|
||||
return_hf_dict=True,
|
||||
)
|
||||
|
||||
def _apply_project_overrides(self, config: ProjectConfig) -> ProjectConfig:
|
||||
dataset_cfg = config.dataset
|
||||
if self.task_config.dataset_overrides:
|
||||
dataset_cfg = self._apply_dataset_overrides(dataset_cfg, self.task_config.dataset_overrides)
|
||||
evaluation_cfg = config.evaluation
|
||||
if self.task_config.evaluation_overrides:
|
||||
evaluation_cfg = self._apply_simple_overrides(evaluation_cfg, self.task_config.evaluation_overrides)
|
||||
visualization_cfg = config.visualization
|
||||
if self.task_config.visualization_overrides:
|
||||
visualization_cfg = self._apply_simple_overrides(
|
||||
visualization_cfg, self.task_config.visualization_overrides
|
||||
)
|
||||
model_cfg = config.model
|
||||
if self.task_config.model_overrides:
|
||||
model_cfg = self._apply_simple_overrides(model_cfg, self.task_config.model_overrides)
|
||||
training_cfg = config.training
|
||||
if self.task_config.training_overrides:
|
||||
training_cfg = self._apply_simple_overrides(training_cfg, self.task_config.training_overrides)
|
||||
return replace(
|
||||
config,
|
||||
dataset=dataset_cfg,
|
||||
model=model_cfg,
|
||||
training=training_cfg,
|
||||
evaluation=evaluation_cfg,
|
||||
visualization=visualization_cfg,
|
||||
)
|
||||
|
||||
def _apply_dataset_overrides(self, dataset_cfg: DatasetConfig, overrides: Dict[str, Any]) -> DatasetConfig:
|
||||
overrides = dict(overrides)
|
||||
extra_updates = overrides.pop("extra_params", {})
|
||||
merged_extra = dict(dataset_cfg.extra_params or {})
|
||||
merged_extra.update(extra_updates)
|
||||
return replace(dataset_cfg, **overrides, extra_params=merged_extra)
|
||||
|
||||
def _apply_simple_overrides(self, cfg, overrides: Dict[str, Any]):
|
||||
overrides = dict(overrides)
|
||||
return replace(cfg, **overrides)
|
||||
|
||||
def _default_data_root(self) -> str:
|
||||
return self.project_config.dataset.data_root
|
||||
|
||||
def _default_test_file(self) -> str:
|
||||
dataset_cfg = self.project_config.dataset
|
||||
candidate = dataset_cfg.split_file or "test.txt"
|
||||
candidate_path = Path(candidate)
|
||||
if candidate_path.is_absolute():
|
||||
return str(candidate_path)
|
||||
return str(Path(dataset_cfg.data_root) / candidate)
|
||||
|
||||
def _default_output_dir(self) -> str:
|
||||
return self.project_config.evaluation.output_dir
|
||||
|
||||
def _run_train(self, step: TaskStepConfig) -> None:
|
||||
train_dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
|
||||
eval_dataset = None
|
||||
if step.eval_split:
|
||||
eval_dataset = self._build_dataset(step.eval_split, step.eval_split_file)
|
||||
trainer_builder = FineTuningTrainer(
|
||||
adapter=self.adapter,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
training_config=_replace_dataclass(
|
||||
self.project_config.training,
|
||||
dict(step.params),
|
||||
),
|
||||
)
|
||||
artifacts = trainer_builder.build()
|
||||
train_result = artifacts.trainer.train()
|
||||
LOGGER.info("Training result: %s", train_result)
|
||||
artifacts.trainer.save_model(self.project_config.training.output_dir)
|
||||
if eval_dataset:
|
||||
metrics = artifacts.trainer.evaluate()
|
||||
LOGGER.info("Evaluation metrics: %s", metrics)
|
||||
|
||||
def _run_evaluate(self, step: TaskStepConfig) -> None:
|
||||
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
|
||||
evaluation_cfg = _replace_dataclass(
|
||||
self.project_config.evaluation,
|
||||
{**dict(step.params), "max_samples": step.limit},
|
||||
)
|
||||
evaluator = PipelineEvaluator(
|
||||
dataset=dataset,
|
||||
adapter=self.adapter,
|
||||
config=evaluation_cfg,
|
||||
)
|
||||
summary = evaluator.run()
|
||||
LOGGER.info("Evaluation summary: %s", summary)
|
||||
|
||||
def _run_visualize(self, step: TaskStepConfig) -> None:
|
||||
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
|
||||
vis_config = _replace_dataclass(
|
||||
self.project_config.visualization,
|
||||
{**dict(step.params), "num_samples": step.limit or self.project_config.visualization.num_samples},
|
||||
)
|
||||
overlay = OverlayGenerator(vis_config)
|
||||
pipe = self.adapter.build_pipeline()
|
||||
limit = min(vis_config.num_samples, len(dataset))
|
||||
for idx in range(limit):
|
||||
sample = dataset[idx]
|
||||
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
|
||||
pred_mask = extract_mask_from_pipeline_output(preds)
|
||||
mask = sample.get("labels", {}).get("mask")
|
||||
overlay.visualize_sample(
|
||||
image=sample["pixel_values"],
|
||||
prediction=pred_mask,
|
||||
mask=mask,
|
||||
metadata=sample.get("metadata"),
|
||||
)
|
||||
LOGGER.info("Saved overlays to %s", vis_config.save_dir)
|
||||
|
||||
def _run_bbox_inference(self, step: TaskStepConfig) -> None:
|
||||
params = dict(step.params)
|
||||
data_root = params.get("data_root", self._default_data_root())
|
||||
test_file = params.get("test_file", self._default_test_file())
|
||||
expand_ratio = params.get("expand_ratio", params.get("bbox_expand_ratio", 0.05))
|
||||
output_dir = params.get("output_dir", self._default_output_dir())
|
||||
model_id = params.get("model_id", self.project_config.model.name_or_path)
|
||||
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
|
||||
bbox_process_test_set(
|
||||
data_root=data_root,
|
||||
test_file=test_file,
|
||||
predictor=predictor,
|
||||
output_dir=output_dir,
|
||||
expand_ratio=expand_ratio,
|
||||
)
|
||||
|
||||
def _run_point_inference(self, step: TaskStepConfig) -> None:
|
||||
params = dict(step.params)
|
||||
data_root = params.get("data_root", self._default_data_root())
|
||||
test_file = params.get("test_file", self._default_test_file())
|
||||
num_points = params.get("num_points", 5)
|
||||
per_component = params.get("per_component", False)
|
||||
output_dir = params.get("output_dir") or f"./results/point_prompt_{num_points}pts_hf"
|
||||
model_id = params.get("model_id", self.project_config.model.name_or_path)
|
||||
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
|
||||
point_process_test_set(
|
||||
data_root=data_root,
|
||||
test_file=test_file,
|
||||
predictor=predictor,
|
||||
output_dir=output_dir,
|
||||
num_points=num_points,
|
||||
per_component=per_component,
|
||||
)
|
||||
|
||||
def _run_legacy_evaluation(self, step: TaskStepConfig) -> None:
|
||||
params = dict(step.params)
|
||||
data_root = params.get("data_root", self._default_data_root())
|
||||
test_file = params.get("test_file", self._default_test_file())
|
||||
output_dir = params.get("output_dir", self._default_output_dir())
|
||||
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
|
||||
compute_skeleton = params.get("compute_skeleton", True)
|
||||
legacy_evaluate_test_set(
|
||||
data_root=data_root,
|
||||
test_file=test_file,
|
||||
pred_dir=pred_dir,
|
||||
output_dir=output_dir,
|
||||
compute_skeleton=compute_skeleton,
|
||||
)
|
||||
|
||||
def _run_legacy_visualization(self, step: TaskStepConfig) -> None:
|
||||
params = dict(step.params)
|
||||
data_root = params.get("data_root", self._default_data_root())
|
||||
test_file = params.get("test_file", self._default_test_file())
|
||||
output_dir = params.get("output_dir", self._default_output_dir())
|
||||
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
|
||||
num_samples = params.get("num_samples", 20)
|
||||
save_all = params.get("save_all", False)
|
||||
results_csv = params.get("results_csv", str(Path(output_dir) / "evaluation_results.csv"))
|
||||
legacy_visualize_test_set(
|
||||
data_root=data_root,
|
||||
test_file=test_file,
|
||||
pred_dir=pred_dir,
|
||||
output_dir=output_dir,
|
||||
results_csv=results_csv if Path(results_csv).exists() else None,
|
||||
num_samples=num_samples,
|
||||
save_all=save_all,
|
||||
)
|
||||
if params.get("create_metrics_plot", True):
|
||||
create_metrics_distribution_plot(results_csv, output_dir)
|
||||
@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from .config import TaskConfig
|
||||
|
||||
|
||||
class TaskRegistry:
|
||||
"""
|
||||
Holds named task configs for reuse.
|
||||
"""
|
||||
|
||||
_registry: Dict[str, TaskConfig] = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, task: TaskConfig) -> TaskConfig:
|
||||
cls._registry[task.name] = task
|
||||
return task
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> TaskConfig:
|
||||
if name not in cls._registry:
|
||||
raise KeyError(f"Task '{name}' is not registered.")
|
||||
return cls._registry[name]
|
||||
|
||||
@classmethod
|
||||
def available(cls) -> Dict[str, TaskConfig]:
|
||||
return dict(cls._registry)
|
||||
@ -1,44 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from .config import TaskConfig
|
||||
from .io import load_task_from_toml
|
||||
from .pipeline import TaskRunner
|
||||
from .registry import TaskRegistry
|
||||
|
||||
# ensure built-in tasks are registered when CLI runs
|
||||
from . import examples # noqa: F401
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TaskCLIArguments:
|
||||
task_name: Optional[str] = None
|
||||
task_file: Optional[str] = None
|
||||
|
||||
|
||||
def resolve_task(cli_args: TaskCLIArguments) -> TaskConfig:
|
||||
if not cli_args.task_name and not cli_args.task_file:
|
||||
raise ValueError("Provide either --task_name or --task_file.")
|
||||
if cli_args.task_file:
|
||||
return load_task_from_toml(cli_args.task_file)
|
||||
return TaskRegistry.get(cli_args.task_name)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = HfArgumentParser(TaskCLIArguments)
|
||||
(cli_args,) = parser.parse_args_into_dataclasses()
|
||||
task = resolve_task(cli_args)
|
||||
runner = TaskRunner(task)
|
||||
runner.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
@ -81,17 +81,17 @@ def create_comparison_figure(
|
||||
|
||||
# 原始图像
|
||||
axes[0, 0].imshow(image)
|
||||
axes[0, 0].set_title("Original Image", fontsize=16)
|
||||
axes[0, 0].set_title("Original Image", fontsize=12)
|
||||
axes[0, 0].axis('off')
|
||||
|
||||
# GT 掩码
|
||||
axes[0, 1].imshow(mask_gt, cmap='gray')
|
||||
axes[0, 1].set_title("Ground Truth", fontsize=16)
|
||||
axes[0, 1].set_title("Ground Truth", fontsize=12)
|
||||
axes[0, 1].axis('off')
|
||||
|
||||
# 预测掩码
|
||||
axes[1, 0].imshow(mask_pred, cmap='gray')
|
||||
axes[1, 0].set_title("Prediction", fontsize=16)
|
||||
axes[1, 0].set_title("Prediction", fontsize=12)
|
||||
axes[1, 0].axis('off')
|
||||
|
||||
# 叠加可视化
|
||||
@ -110,16 +110,16 @@ def create_comparison_figure(
|
||||
axes[1, 1].text(
|
||||
0.02, 0.98, legend_text,
|
||||
transform=axes[1, 1].transAxes,
|
||||
fontsize=16,
|
||||
fontsize=10,
|
||||
verticalalignment='top',
|
||||
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
|
||||
)
|
||||
axes[1, 1].set_title("Overlay Visualization", fontsize=16)
|
||||
axes[1, 1].set_title("Overlay Visualization", fontsize=12)
|
||||
axes[1, 1].axis('off')
|
||||
|
||||
# # 设置总标题
|
||||
# if title:
|
||||
# fig.suptitle(title, fontsize=16, fontweight='bold')
|
||||
# 设置总标题
|
||||
if title:
|
||||
fig.suptitle(title, fontsize=14, fontweight='bold')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
@ -1,4 +0,0 @@
|
||||
from .gallery import build_gallery
|
||||
from .overlay import OverlayGenerator
|
||||
|
||||
__all__ = ["OverlayGenerator", "build_gallery"]
|
||||
@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def build_gallery(image_paths: Iterable[Path], output_path: Path, columns: int = 4) -> Path:
|
||||
"""
|
||||
Simple grid composer that stitches overlay PNGs into a gallery.
|
||||
"""
|
||||
image_paths = list(image_paths)
|
||||
if not image_paths:
|
||||
raise ValueError("No images provided for gallery.")
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
images = [Image.open(path).convert("RGB") for path in image_paths]
|
||||
widths, heights = zip(*(img.size for img in images))
|
||||
cell_w = max(widths)
|
||||
cell_h = max(heights)
|
||||
rows = (len(images) + columns - 1) // columns
|
||||
canvas = Image.new("RGB", (cell_w * columns, cell_h * rows), color=(0, 0, 0))
|
||||
for idx, img in enumerate(images):
|
||||
row = idx // columns
|
||||
col = idx % columns
|
||||
canvas.paste(img, (col * cell_w, row * cell_h))
|
||||
canvas.save(output_path)
|
||||
return output_path
|
||||
@ -1,62 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from ..model_configuration import VisualizationConfig
|
||||
|
||||
|
||||
class OverlayGenerator:
|
||||
"""
|
||||
Turns model predictions into side-by-side overlays for quick inspection.
|
||||
"""
|
||||
|
||||
def __init__(self, config: VisualizationConfig) -> None:
|
||||
self.config = config
|
||||
Path(self.config.save_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def visualize_sample(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
prediction: np.ndarray,
|
||||
mask: Optional[np.ndarray],
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> Path:
|
||||
overlay = self._compose_overlay(image, prediction, mask)
|
||||
filename = (
|
||||
metadata.get("image_name", "sample")
|
||||
if metadata
|
||||
else "sample"
|
||||
)
|
||||
target = Path(self.config.save_dir) / f"{filename}_overlay.png"
|
||||
Image.fromarray(overlay).save(target)
|
||||
return target
|
||||
|
||||
def _compose_overlay(
|
||||
self,
|
||||
image: np.ndarray,
|
||||
prediction: np.ndarray,
|
||||
mask: Optional[np.ndarray],
|
||||
) -> np.ndarray:
|
||||
vis = image.copy()
|
||||
pred_mask = self._normalize(prediction)
|
||||
color = np.zeros_like(vis)
|
||||
color[..., 0] = pred_mask
|
||||
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
|
||||
if mask is not None:
|
||||
gt = self._normalize(mask)
|
||||
color = np.zeros_like(vis)
|
||||
color[..., 1] = gt
|
||||
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
|
||||
return vis
|
||||
|
||||
def _normalize(self, array: np.ndarray) -> np.ndarray:
|
||||
normalized = array.astype(np.float32)
|
||||
normalized -= normalized.min()
|
||||
denom = normalized.max() or 1.0
|
||||
normalized = normalized / denom
|
||||
normalized = (normalized * 255).astype(np.uint8)
|
||||
return normalized
|
||||
@ -1,58 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, replace
|
||||
from typing import Optional
|
||||
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
from ..dataset import DatasetRegistry
|
||||
from ..evaluation.utils import extract_mask_from_pipeline_output
|
||||
from ..model import ModelRegistry
|
||||
from ..model_configuration import ConfigRegistry
|
||||
from .overlay import OverlayGenerator
|
||||
|
||||
LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisualizationCLIArguments:
|
||||
config_name: str = "sam2_bbox_prompt"
|
||||
model_key: str = "sam2"
|
||||
split: str = "test"
|
||||
split_file: Optional[str] = None
|
||||
num_samples: int = 20
|
||||
device: Optional[str] = None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = HfArgumentParser(VisualizationCLIArguments)
|
||||
(cli_args,) = parser.parse_args_into_dataclasses()
|
||||
project_config = ConfigRegistry.get(cli_args.config_name)
|
||||
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
|
||||
dataset = DatasetRegistry.create(
|
||||
dataset_cfg.name,
|
||||
config=dataset_cfg,
|
||||
return_hf_dict=True,
|
||||
)
|
||||
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
|
||||
overlay = OverlayGenerator(project_config.visualization)
|
||||
pipe = adapter.build_pipeline(device=cli_args.device)
|
||||
limit = min(cli_args.num_samples, len(dataset))
|
||||
for idx in range(limit):
|
||||
sample = dataset[idx]
|
||||
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
|
||||
pred_mask = extract_mask_from_pipeline_output(preds)
|
||||
mask = sample.get("labels", {}).get("mask")
|
||||
overlay.visualize_sample(
|
||||
image=sample["pixel_values"],
|
||||
prediction=pred_mask,
|
||||
mask=mask,
|
||||
metadata=sample.get("metadata"),
|
||||
)
|
||||
LOGGER.info("Saved overlays to %s", project_config.visualization.save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
||||
@ -1,34 +0,0 @@
|
||||
[task]
|
||||
name = "bbox_cli_template"
|
||||
description = "Run legacy bbox-prompt inference + evaluation + visualization"
|
||||
project_config_name = "sam2_bbox_prompt"
|
||||
|
||||
[[steps]]
|
||||
kind = "bbox_inference"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
model_id = "facebook/sam2-hiera-small"
|
||||
output_dir = "./results/bbox_prompt"
|
||||
expand_ratio = 0.05
|
||||
|
||||
[[steps]]
|
||||
kind = "legacy_evaluation"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
output_dir = "./results/bbox_prompt"
|
||||
pred_dir = "./results/bbox_prompt/predictions"
|
||||
compute_skeleton = true
|
||||
|
||||
[[steps]]
|
||||
kind = "legacy_visualization"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
output_dir = "./results/bbox_prompt"
|
||||
pred_dir = "./results/bbox_prompt/predictions"
|
||||
results_csv = "./results/bbox_prompt/evaluation_results.csv"
|
||||
num_samples = 20
|
||||
save_all = false
|
||||
create_metrics_plot = true
|
||||
@ -1,100 +0,0 @@
|
||||
[task]
|
||||
name = "point_cli_template"
|
||||
description = "Run legacy point-prompt inference/eval/visualization for multiple configs"
|
||||
project_config_name = "sam2_bbox_prompt"
|
||||
|
||||
# 1 point config
|
||||
[[steps]]
|
||||
kind = "point_inference"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
model_id = "facebook/sam2-hiera-small"
|
||||
num_points = 1
|
||||
per_component = false
|
||||
output_dir = "./results/point_prompt_1pts_hf"
|
||||
|
||||
[[steps]]
|
||||
kind = "legacy_evaluation"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
output_dir = "./results/point_prompt_1pts_hf"
|
||||
pred_dir = "./results/point_prompt_1pts_hf/predictions"
|
||||
compute_skeleton = true
|
||||
|
||||
[[steps]]
|
||||
kind = "legacy_visualization"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
output_dir = "./results/point_prompt_1pts_hf"
|
||||
pred_dir = "./results/point_prompt_1pts_hf/predictions"
|
||||
results_csv = "./results/point_prompt_1pts_hf/evaluation_results.csv"
|
||||
num_samples = 10
|
||||
save_all = false
|
||||
create_metrics_plot = true
|
||||
|
||||
# 3 point config
|
||||
[[steps]]
|
||||
kind = "point_inference"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
model_id = "facebook/sam2-hiera-small"
|
||||
num_points = 3
|
||||
per_component = false
|
||||
output_dir = "./results/point_prompt_3pts_hf"
|
||||
|
||||
[[steps]]
|
||||
kind = "legacy_evaluation"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
output_dir = "./results/point_prompt_3pts_hf"
|
||||
pred_dir = "./results/point_prompt_3pts_hf/predictions"
|
||||
compute_skeleton = true
|
||||
|
||||
[[steps]]
|
||||
kind = "legacy_visualization"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
output_dir = "./results/point_prompt_3pts_hf"
|
||||
pred_dir = "./results/point_prompt_3pts_hf/predictions"
|
||||
results_csv = "./results/point_prompt_3pts_hf/evaluation_results.csv"
|
||||
num_samples = 10
|
||||
save_all = false
|
||||
create_metrics_plot = true
|
||||
|
||||
# 5 point config
|
||||
[[steps]]
|
||||
kind = "point_inference"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
model_id = "facebook/sam2-hiera-small"
|
||||
num_points = 5
|
||||
per_component = false
|
||||
output_dir = "./results/point_prompt_5pts_hf"
|
||||
|
||||
[[steps]]
|
||||
kind = "legacy_evaluation"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
output_dir = "./results/point_prompt_5pts_hf"
|
||||
pred_dir = "./results/point_prompt_5pts_hf/predictions"
|
||||
compute_skeleton = true
|
||||
|
||||
[[steps]]
|
||||
kind = "legacy_visualization"
|
||||
[steps.params]
|
||||
data_root = "./crack500"
|
||||
test_file = "./crack500/test.txt"
|
||||
output_dir = "./results/point_prompt_5pts_hf"
|
||||
pred_dir = "./results/point_prompt_5pts_hf/predictions"
|
||||
results_csv = "./results/point_prompt_5pts_hf/evaluation_results.csv"
|
||||
num_samples = 10
|
||||
save_all = false
|
||||
create_metrics_plot = true
|
||||
Loading…
x
Reference in New Issue
Block a user