diff --git a/configs/preprocesser.json b/configs/preprocesser.json
deleted file mode 100644
index 43c28ff..0000000
--- a/configs/preprocesser.json
+++ /dev/null
@@ -1,35 +0,0 @@
-{
-  "crop_size": null,
-  "data_format": "channels_first",
-  "default_to_square": true,
-  "device": null,
-  "disable_grouping": null,
-  "do_center_crop": null,
-  "do_convert_rgb": true,
-  "do_normalize": false,
-  "do_rescale": false,
-  "do_resize": false,
-  "image_mean": [
-    0.485,
-    0.456,
-    0.406
-  ],
-  "image_processor_type": "Sam2ImageProcessorFast",
-  "image_std": [
-    0.229,
-    0.224,
-    0.225
-  ],
-  "input_data_format": null,
-  "mask_size": {
-    "height": 256,
-    "width": 256
-  },
-  "processor_class": "Sam2VideoProcessor",
-  "resample": 2,
-  "rescale_factor": 0.00392156862745098,
-  "return_tensors": null,
-  "size": {
-    "longest_edge": 1024
-  }
-}
diff --git a/pixi.lock b/pixi.lock
index fe9b1ed..c8ef520 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -34,6 +34,7 @@ environments:
       - conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_ha0e22de_103.conda
       - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-h8577fbf_0.conda
       - conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda
+      - pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
@@ -48,8 +49,10 @@
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl
+      - pypi: https://mirror.nju.edu.cn/pypi/web/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl
+      - pypi: https://mirror.nju.edu.cn/pypi/web/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/f1/df/8ee1c5dd1e3308b5d5b2f2dfea323bb2f3827da8d654abb6642051199049/ipython-9.8.0-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl
@@ -81,6 +84,7 @@
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/3b/6c/99acb2f9eb85c29fc6f3a7ac4dccfd992e22666dd08a642b303311326a97/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
+      - pypi: https://mirror.nju.edu.cn/pypi/web/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
@@ -88,6 +92,7 @@
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/4f/87/424511bdcd02c8d7acf9f65caa09f291a519b16bd83c3fb3374b3d4ae951/pillow-12.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl
+      - pypi: https://mirror.nju.edu.cn/pypi/web/packages/4b/a6/38c8e2f318bf67d338f4d629e93b0b4b9af331f455f0390ea8ce4a099b26/portalocker-3.2.0-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/12/ff/e93136587c00a543f4bc768b157fac2c47cd77b180d4f4e5c6efb6ea53a2/psutil-7.2.0-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
@@ -122,6 +127,7 @@
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl
       - pypi: https://mirror.nju.edu.cn/pypi/web/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
+      - pypi: /home/dustella/projects/sam2
 packages:
 - conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
   sha256: fe51de6107f9edc7aa4f786a70f4a883943bc9d39b3bb7307c04c41410990726
@@ -144,6 +150,12 @@ packages:
   purls: []
   size: 23621
   timestamp: 1650670423406
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
+  name: antlr4-python3-runtime
+  version: 4.9.3
+  sha256: f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b
+  requires_dist:
+  - typing ; python_full_version < '3.5'
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
   name: asttokens
   version: 3.0.1
@@ -540,6 +552,15 @@
   - types-tqdm ; extra == 'typing'
   - types-urllib3 ; extra == 'typing'
   requires_python: '>=3.8.0'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl
+  name: hydra-core
+  version: 1.3.2
+  sha256: fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b
+  requires_dist:
+  - omegaconf>=2.2,<2.4
+  - antlr4-python3-runtime==4.9.*
+  - packaging
+  - importlib-resources ; python_full_version < '3.9'
 - conda: https://conda.anaconda.org/conda-forge/linux-64/icu-78.1-h33c6efd_0.conda
   sha256: 7d6463d0be5092b2ae8f2fad34dc84de83eab8bd44cc0d4be8931881c973c48f
   md5: 518e9bbbc3e3486d6a4519192ba690f8
@@ -623,6 +644,17 @@
   - sphinx<6 ; extra == 'full'
   - tifffile ; extra == 'full'
   requires_python: '>=3.9'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz
+  name: iopath
+  version: 0.1.10
+  sha256: 3311c16a4d9137223e20f141655759933e1eda24f8bff166af834af3c645ef01
+  requires_dist:
+  - tqdm
+  - typing-extensions
+  - portalocker
+  - dataclasses ; python_full_version < '3.7'
+  - boto3 ; extra == 'aws'
+  requires_python: '>=3.6'
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
   name: ipykernel
   version: 7.1.0
@@ -1174,6 +1206,15 @@
   version: 12.8.90
   sha256: 5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f
   requires_python: '>=3'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl
+  name: omegaconf
+  version: 2.3.0
+  sha256: 7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b
+  requires_dist:
+  - antlr4-python3-runtime==4.9.*
+  - pyyaml>=5.1.0
+  - dataclasses ; python_full_version == '3.6.*'
+  requires_python: '>=3.6'
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
   name: opencv-python
   version: 4.12.0.88
@@ -1355,6 +1396,25 @@
  - pytest>=8.4.2 ; extra == 'test'
  - mypy>=1.18.2 ; extra == 'type'
  requires_python: '>=3.10'
+- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4b/a6/38c8e2f318bf67d338f4d629e93b0b4b9af331f455f0390ea8ce4a099b26/portalocker-3.2.0-py3-none-any.whl
+  name: portalocker
+  version: 3.2.0
+  sha256: 3cdc5f565312224bc570c49337bd21428bba0ef363bbcf58b9ef4a9f11779968
+  requires_dist:
+  - pywin32>=226 ; sys_platform == 'win32'
+  - portalocker[tests] ; extra == 'docs'
+  - coverage-conditional-plugin>=0.9.0 ; extra == 'tests'
+  - portalocker[redis] ; extra == 'tests'
+  - pytest-cov>=2.8.1 ; extra == 'tests'
+  - pytest-mypy>=0.8.0 ; extra == 'tests'
+  - pytest-rerunfailures>=15.0 ; extra == 'tests'
+  - pytest-timeout>=2.1.0 ; extra == 'tests'
+  - pytest>=5.4.1 ; extra == 'tests'
+  - sphinx>=6.0.0 ; extra == 'tests'
+  - types-pywin32>=310.0.0.20250429 ; extra == 'tests'
+  - types-redis ; extra == 'tests'
+  - redis ; extra == 'redis'
+  requires_python: '>=3.9'
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
   name: prompt-toolkit
   version: 3.0.52
@@ -1544,6 +1604,44 @@
  - safetensors[testing] ; extra == 'all'
  - safetensors[all] ; extra == 'dev'
  requires_python: '>=3.9'
+- pypi: /home/dustella/projects/sam2
+  name: sam-2
+  version: '1.0'
+  sha256: f3cc0abfc266b6f57a457d26bdf52fcaa92567d8fa27ba6ff59ebe60bec0ac69
+  requires_dist:
+  - torch>=2.5.1
+  - torchvision>=0.20.1
+  - numpy>=1.24.4
+  - tqdm>=4.66.1
+  - hydra-core>=1.3.2
+  - iopath>=0.1.10
+  - pillow>=9.4.0
+  - matplotlib>=3.9.1 ; extra == 'notebooks'
+  - jupyter>=1.0.0 ; extra == 'notebooks'
+  - opencv-python>=4.7.0 ; extra == 'notebooks'
+  - eva-decord>=0.6.1 ; extra == 'notebooks'
+  - flask>=3.0.3 ; extra == 'interactive-demo'
+  - flask-cors>=5.0.0 ; extra == 'interactive-demo'
+  - av>=13.0.0 ; extra == 'interactive-demo'
+  - dataclasses-json>=0.6.7 ; extra == 'interactive-demo'
+  - eva-decord>=0.6.1 ; extra == 'interactive-demo'
+  - gunicorn>=23.0.0 ; extra == 'interactive-demo'
+  - imagesize>=1.4.1 ; extra == 'interactive-demo'
+  - pycocotools>=2.0.8 ; extra == 'interactive-demo'
+  - strawberry-graphql>=0.243.0 ; extra == 'interactive-demo'
+  - black==24.2.0 ; extra == 'dev'
+  - usort==1.0.2 ; extra == 'dev'
+  - ufmt==2.0.0b2 ; extra == 'dev'
+  - fvcore>=0.1.5.post20221221 ; extra == 'dev'
+  - pandas>=2.2.2 ; extra == 'dev'
+  - scikit-image>=0.24.0 ; extra == 'dev'
+  - tensorboard>=2.17.0 ; extra == 'dev'
+  - pycocotools>=2.0.8 ; extra == 'dev'
+  - tensordict>=0.6.0 ; extra == 'dev'
+  - opencv-python>=4.7.0 ; extra == 'dev'
+  - submitit>=1.5.1 ; extra == 'dev'
+  requires_python: '>=3.10.0'
+  editable: true
 - pypi: https://mirror.nju.edu.cn/pypi/web/packages/f4/a2/70401a107d6d7466d64b466927e6b96fcefa99d57494b972608e2f8be50f/scikit_image-0.26.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
   name: scikit-image
   version: 0.26.0
diff --git a/pixi.toml b/pixi.toml
index cf903ff..cdf9096 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -26,3 +26,4 @@ tqdm = ">=4.65.0"
 pandas = ">=2.0.0"
 transformers = ">=4.57.3, <5"
 ipykernel = ">=7.1.0, <8"
+sam-2 = { path = "/home/dustella/projects/sam2", editable = true }
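The editable sam-2 entry above is what lets the rewritten scripts below import the local SAM2 checkout directly. A minimal smoke test of that wiring, assuming the checkout keeps the usual sam2 package layout:

    # sketch: verify the editable install resolves to the local checkout
    import sam2
    from sam2.build_sam import build_sam2  # import used by the scripts in this patch
    print(sam2.__file__)  # expect a path under /home/dustella/projects/sam2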
description="SAM2 边界框提示方式 - Crack500 数据集完整评估" ) - parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录") - parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径") - parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID") - parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="输出目录") - parser.add_argument("--expand_ratio", type=float, default=0.05, help="边界框扩展比例 (0.0-1.0)") - parser.add_argument("--num_vis", type=int, default=20, help="可视化样本数量") - parser.add_argument("--vis_all", action="store_true", help="可视化所有样本") - parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤") - parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤") - parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤") + + # 数据集参数 parser.add_argument( - "--config_name", - type=str, - default="sam2_bbox_prompt", - help="ProjectConfig 名称(来自 ConfigRegistry)", + "--data_root", type=str, default="./crack500", + help="数据集根目录" ) parser.add_argument( - "--task_file", - type=str, - default=None, - help="可选:指向 TOML 任务配置(若提供则忽略其余 CLI 参数)", + "--test_file", type=str, default="./crack500/test.txt", + help="测试集文件路径" ) - args = parser.parse_args() - return BBoxCLIArgs( - data_root=args.data_root, - test_file=args.test_file, - model_id=args.model_id, - output_dir=args.output_dir, - expand_ratio=args.expand_ratio, - num_vis=args.num_vis, - vis_all=args.vis_all, - skip_inference=args.skip_inference, - skip_evaluation=args.skip_evaluation, - skip_visualization=args.skip_visualization, - config_name=args.config_name, - task_file=args.task_file, + + # 模型参数 + parser.add_argument( + "--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt", + help="SAM2 模型检查点路径" ) - - -def build_cli_task(args: BBoxCLIArgs) -> TaskConfig: - steps: List[TaskStepConfig] = [] - common = { - "data_root": args.data_root, - "test_file": args.test_file, - "model_id": args.model_id, - "output_dir": args.output_dir, - } - if not args.skip_inference: - steps.append( - TaskStepConfig( - kind="bbox_inference", - params={**common, "expand_ratio": args.expand_ratio}, - ) - ) - if not args.skip_evaluation: - steps.append( - TaskStepConfig( - kind="legacy_evaluation", - params={ - **common, - "pred_dir": f"{args.output_dir}/predictions", - "compute_skeleton": True, - }, - ) - ) - if not args.skip_visualization: - steps.append( - TaskStepConfig( - kind="legacy_visualization", - params={ - **common, - "pred_dir": f"{args.output_dir}/predictions", - "results_csv": f"{args.output_dir}/evaluation_results.csv", - "num_samples": args.num_vis, - "save_all": args.vis_all, - "create_metrics_plot": True, - }, - ) - ) - return TaskConfig( - name="bbox_cli_run", - description="Legacy bbox prompt pipeline executed via TaskRunner", - project_config_name=args.config_name, - steps=steps, + parser.add_argument( + "--model_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_s.yaml", + help="SAM2 模型配置文件" ) + + # 输出参数 + parser.add_argument( + "--output_dir", type=str, default="./results/bbox_prompt", + help="输出目录" + ) + + # 边界框参数 + parser.add_argument( + "--expand_ratio", type=float, default=0.05, + help="边界框扩展比例 (0.0-1.0)" + ) + + # 可视化参数 + parser.add_argument( + "--num_vis", type=int, default=20, + help="可视化样本数量" + ) + parser.add_argument( + "--vis_all", action="store_true", + help="可视化所有样本" + ) + + # 流程控制 + parser.add_argument( + "--skip_inference", 
action="store_true", + help="跳过推理步骤(使用已有预测结果)" + ) + parser.add_argument( + "--skip_evaluation", action="store_true", + help="跳过评估步骤" + ) + parser.add_argument( + "--skip_visualization", action="store_true", + help="跳过可视化步骤" + ) + + return parser.parse_args() -def main() -> None: - logging.basicConfig(level=logging.INFO) +def main(): + """主函数""" args = parse_args() - if args.task_file: - task = load_task_from_toml(args.task_file) + + print("=" * 80) + print("SAM2 边界框提示方式 - Crack500 数据集完整评估") + print("=" * 80) + print(f"数据集根目录: {args.data_root}") + print(f"测试集文件: {args.test_file}") + print(f"模型检查点: {args.checkpoint}") + print(f"模型配置: {args.model_cfg}") + print(f"边界框扩展比例: {args.expand_ratio * 100}%") + print(f"输出目录: {args.output_dir}") + print("=" * 80) + + # 检查必要文件 + if not os.path.exists(args.data_root): + print(f"\n错误: 数据集目录不存在 {args.data_root}") + return + + if not os.path.exists(args.test_file): + print(f"\n错误: 测试集文件不存在 {args.test_file}") + return + + # 创建输出目录 + os.makedirs(args.output_dir, exist_ok=True) + + # ========== 步骤 1: 推理 ========== + if not args.skip_inference: + print("\n" + "=" * 80) + print("步骤 1/3: 使用 SAM2 进行推理") + print("=" * 80) + + # 检查模型文件 + if not os.path.exists(args.checkpoint): + print(f"\n错误: 模型检查点不存在 {args.checkpoint}") + print("请先下载 SAM2 模型权重!") + print("运行: cd sam2/checkpoints && ./download_ckpts.sh") + return + + try: + import torch + + # 检查 CUDA + if not torch.cuda.is_available(): + print("警告: CUDA 不可用,将使用 CPU(速度会很慢)") + else: + print(f"使用 GPU: {torch.cuda.get_device_name(0)}") + + # 加载模型 + print("\n加载 SAM2 模型...") + start_time = time.time() + sam2_model = build_sam2(args.model_cfg, args.checkpoint) + predictor = SAM2ImagePredictor(sam2_model) + print(f"模型加载完成!耗时: {time.time() - start_time:.2f}s") + + # 处理测试集 + print("\n开始推理...") + start_time = time.time() + results = process_test_set( + data_root=args.data_root, + test_file=args.test_file, + predictor=predictor, + output_dir=args.output_dir, + expand_ratio=args.expand_ratio + ) + print(f"推理完成!耗时: {time.time() - start_time:.2f}s") + print(f"成功处理 {len(results)} 张图像") + + except Exception as e: + print(f"\n推理过程出错: {str(e)}") + import traceback + traceback.print_exc() + return else: - task = build_cli_task(args) - if not task.steps: - raise ValueError("No steps configured for bbox evaluation. 
Please enable at least one stage.") - runner = TaskRunner(task) - runner.run() + print("\n跳过推理步骤(使用已有预测结果)") + + # ========== 步骤 2: 评估 ========== + if not args.skip_evaluation: + print("\n" + "=" * 80) + print("步骤 2/3: 评估预测结果") + print("=" * 80) + + pred_dir = os.path.join(args.output_dir, "predictions") + + if not os.path.exists(pred_dir): + print(f"\n错误: 预测目录不存在 {pred_dir}") + print("请先运行推理步骤!") + return + + try: + start_time = time.time() + df_results = evaluate_test_set( + data_root=args.data_root, + test_file=args.test_file, + pred_dir=pred_dir, + output_dir=args.output_dir, + compute_skeleton=True + ) + print(f"\n评估完成!耗时: {time.time() - start_time:.2f}s") + + except Exception as e: + print(f"\n评估过程出错: {str(e)}") + import traceback + traceback.print_exc() + return + else: + print("\n跳过评估步骤") + + # ========== 步骤 3: 可视化 ========== + if not args.skip_visualization: + print("\n" + "=" * 80) + print("步骤 3/3: 生成可视化结果") + print("=" * 80) + + pred_dir = os.path.join(args.output_dir, "predictions") + results_csv = os.path.join(args.output_dir, "evaluation_results.csv") + + if not os.path.exists(pred_dir): + print(f"\n错误: 预测目录不存在 {pred_dir}") + return + + try: + start_time = time.time() + + # 可视化样本 + visualize_test_set( + data_root=args.data_root, + test_file=args.test_file, + pred_dir=pred_dir, + output_dir=args.output_dir, + results_csv=results_csv if os.path.exists(results_csv) else None, + num_samples=args.num_vis, + save_all=args.vis_all + ) + + # 创建指标分布图 + if os.path.exists(results_csv): + create_metrics_distribution_plot(results_csv, args.output_dir) + + print(f"\n可视化完成!耗时: {time.time() - start_time:.2f}s") + + except Exception as e: + print(f"\n可视化过程出错: {str(e)}") + import traceback + traceback.print_exc() + return + else: + print("\n跳过可视化步骤") + + # ========== 完成 ========== + print("\n" + "=" * 80) + print("所有步骤完成!") + print("=" * 80) + print(f"\n结果保存在: {args.output_dir}") + print(f" - 预测掩码: {os.path.join(args.output_dir, 'predictions')}") + print(f" - 评估结果: {os.path.join(args.output_dir, 'evaluation_results.csv')}") + print(f" - 统计摘要: {os.path.join(args.output_dir, 'evaluation_summary.json')}") + print(f" - 可视化图像: {os.path.join(args.output_dir, 'visualizations')}") + print("=" * 80) if __name__ == "__main__": diff --git a/run_point_evaluation.py b/run_point_evaluation.py index 7f4dd3b..b97d284 100755 --- a/run_point_evaluation.py +++ b/run_point_evaluation.py @@ -1,222 +1,272 @@ #!/usr/bin/env python3 """ -SAM2 点提示方式完整评估流程 (TaskRunner 驱动版本) +SAM2 点提示方式完整评估流程 +支持 1, 3, 5 个点的对比实验 """ -import argparse -import logging import os -from dataclasses import dataclass +import sys +import argparse +import time from pathlib import Path -from typing import Dict, List, Optional -import pandas as pd +# 添加 src 目录到路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src')) -from src.tasks.config import TaskConfig, TaskStepConfig -from src.tasks.io import load_task_from_toml -from src.tasks.pipeline import TaskRunner +from point_prompt import process_test_set, build_sam2, SAM2ImagePredictor +from evaluation import evaluate_test_set +from visualization import visualize_test_set, create_metrics_distribution_plot -@dataclass -class PointCLIArgs: - data_root: str - test_file: str - model_id: str - point_configs: List[int] - per_component: bool - num_vis: int - skip_inference: bool - skip_evaluation: bool - skip_visualization: bool - skip_comparison: bool - comparison_dir: str - config_name: str - task_file: Optional[str] - - -def parse_args() -> PointCLIArgs: - parser = 
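The driver above just chains three library calls, so the same flow can be reproduced interactively. A minimal sketch using the script's own defaults (function and argument names as defined in this patch):

    import sys
    sys.path.insert(0, "src")  # mirror the script's path setup

    from bbox_prompt import process_test_set, build_sam2, SAM2ImagePredictor
    from evaluation import evaluate_test_set

    predictor = SAM2ImagePredictor(build_sam2(
        "configs/sam2.1/sam2.1_hiera_s.yaml",
        "../sam2/checkpoints/sam2.1_hiera_small.pt",
    ))
    process_test_set(data_root="./crack500", test_file="./crack500/test.txt",
                     predictor=predictor, output_dir="./results/bbox_prompt",
                     expand_ratio=0.05)
    evaluate_test_set(data_root="./crack500", test_file="./crack500/test.txt",
                      pred_dir="./results/bbox_prompt/predictions",
                      output_dir="./results/bbox_prompt", compute_skeleton=True)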
diff --git a/run_point_evaluation.py b/run_point_evaluation.py
index 7f4dd3b..b97d284 100755
--- a/run_point_evaluation.py
+++ b/run_point_evaluation.py
@@ -1,222 +1,272 @@
 #!/usr/bin/env python3
 """
-SAM2 point-prompt evaluation pipeline (TaskRunner-driven version)
+SAM2 point-prompt evaluation pipeline
+Supports comparison experiments with 1, 3, and 5 points
 """
-import argparse
-import logging
 import os
-from dataclasses import dataclass
+import sys
+import argparse
+import time
 from pathlib import Path
-from typing import Dict, List, Optional
 
-import pandas as pd
+# Add the src directory to the import path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
 
-from src.tasks.config import TaskConfig, TaskStepConfig
-from src.tasks.io import load_task_from_toml
-from src.tasks.pipeline import TaskRunner
+from point_prompt import process_test_set, build_sam2, SAM2ImagePredictor
+from evaluation import evaluate_test_set
+from visualization import visualize_test_set, create_metrics_distribution_plot
 
 
-@dataclass
-class PointCLIArgs:
-    data_root: str
-    test_file: str
-    model_id: str
-    point_configs: List[int]
-    per_component: bool
-    num_vis: int
-    skip_inference: bool
-    skip_evaluation: bool
-    skip_visualization: bool
-    skip_comparison: bool
-    comparison_dir: str
-    config_name: str
-    task_file: Optional[str]
-
-
-def parse_args() -> PointCLIArgs:
-    parser = argparse.ArgumentParser(description="SAM2 point prompts - TaskRunner-driven multi-point comparison")
-    parser.add_argument("--data_root", type=str, default="./crack500", help="Dataset root directory")
-    parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="Path to the test split file")
-    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 model ID")
-    parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="Point counts to test")
-    parser.add_argument("--per_component", action="store_true", help="Sample points independently per connected component")
-    parser.add_argument("--num_vis", type=int, default=10, help="Number of samples to visualize")
-    parser.add_argument("--skip_inference", action="store_true", help="Skip the inference stage")
-    parser.add_argument("--skip_evaluation", action="store_true", help="Skip the evaluation stage")
-    parser.add_argument("--skip_visualization", action="store_true", help="Skip the visualization stage")
-    parser.add_argument("--skip_comparison", action="store_true", help="Skip comparison of experiment results")
-    parser.add_argument("--comparison_dir", type=str, default="./results", help="Output directory for comparison results")
-    parser.add_argument(
-        "--config_name",
-        type=str,
-        default="sam2_bbox_prompt",
-        help="ProjectConfig name (from the ConfigRegistry)",
-    )
-    parser.add_argument(
-        "--task_file",
-        type=str,
-        default=None,
-        help="Optional: path to a TOML task config (if provided, CLI assembly is skipped)",
-    )
-    args = parser.parse_args()
-    return PointCLIArgs(
-        data_root=args.data_root,
-        test_file=args.test_file,
-        model_id=args.model_id,
-        point_configs=args.point_configs,
-        per_component=args.per_component,
-        num_vis=args.num_vis,
-        skip_inference=args.skip_inference,
-        skip_evaluation=args.skip_evaluation,
-        skip_visualization=args.skip_visualization,
-        skip_comparison=args.skip_comparison,
-        comparison_dir=args.comparison_dir,
-        config_name=args.config_name,
-        task_file=args.task_file,
-    )
-
-
-def default_output_dir(num_points: int, per_component: bool) -> str:
+def run_single_experiment(
+    data_root: str,
+    test_file: str,
+    checkpoint: str,
+    model_cfg: str,
+    num_points: int,
+    per_component: bool = False
+):
+    """Run the experiment for a single point count."""
+
+    # Set up the output directory
     if per_component:
-        return f"./results/point_prompt_{num_points}pts_per_comp_hf"
-    return f"./results/point_prompt_{num_points}pts_hf"
-
-
-def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
-    steps: List[TaskStepConfig] = []
-    common = {
-        "data_root": args.data_root,
-        "test_file": args.test_file,
-        "model_id": args.model_id,
-        "output_dir": output_dir,
-    }
-    if not args.skip_inference:
-        steps.append(
-            TaskStepConfig(
-                kind="point_inference",
-                params={
-                    **common,
-                    "num_points": num_points,
-                    "per_component": args.per_component,
-                },
-            )
-        )
-    if not args.skip_evaluation:
-        steps.append(
-            TaskStepConfig(
-                kind="legacy_evaluation",
-                params={
-                    **common,
-                    "pred_dir": f"{output_dir}/predictions",
-                    "compute_skeleton": True,
-                },
-            )
-        )
-    if not args.skip_visualization:
-        steps.append(
-            TaskStepConfig(
-                kind="legacy_visualization",
-                params={
-                    **common,
-                    "pred_dir": f"{output_dir}/predictions",
-                    "results_csv": f"{output_dir}/evaluation_results.csv",
-                    "num_samples": args.num_vis,
-                    "save_all": False,
-                    "create_metrics_plot": True,
-                },
-            )
-        )
-    return TaskConfig(
-        name=f"point_cli_{num_points}",
-        description=f"Legacy point prompt pipeline ({num_points} pts)",
-        project_config_name=args.config_name,
-        steps=steps,
+        output_dir = f"./results/point_prompt_{num_points}pts_per_comp"
+    else:
+        output_dir = f"./results/point_prompt_{num_points}pts"
+
+    print("\n" + "=" * 80)
+    print(f"Experiment config: {num_points} points ({'per component' if per_component else 'global skeleton'})")
+    print("=" * 80)
+
+    # Load the model
+    print("\nLoading SAM2 model...")
+    import torch
+    sam2_model = build_sam2(model_cfg, checkpoint)
+    predictor = SAM2ImagePredictor(sam2_model)
+    print("Model loaded!")
+
+    # Inference
+    print(f"\nStep 1/3: inference ({num_points} points)")
+    start_time = time.time()
+    results = process_test_set(
+        data_root=data_root,
+        test_file=test_file,
+        predictor=predictor,
+        output_dir=output_dir,
+        num_points=num_points,
+        per_component=per_component
     )
+    print(f"Inference finished! Elapsed: {time.time() - start_time:.2f}s")
+
+    # Evaluation
+    print(f"\nStep 2/3: evaluation")
+    pred_dir = os.path.join(output_dir, "predictions")
+    start_time = time.time()
+    df_results = evaluate_test_set(
+        data_root=data_root,
+        test_file=test_file,
+        pred_dir=pred_dir,
+        output_dir=output_dir,
+        compute_skeleton=True
+    )
+    print(f"Evaluation finished! Elapsed: {time.time() - start_time:.2f}s")
+
+    # Visualization
+    print(f"\nStep 3/3: visualization")
+    results_csv = os.path.join(output_dir, "evaluation_results.csv")
+    start_time = time.time()
+    visualize_test_set(
+        data_root=data_root,
+        test_file=test_file,
+        pred_dir=pred_dir,
+        output_dir=output_dir,
+        results_csv=results_csv,
+        num_samples=10,
+        save_all=False
+    )
+    create_metrics_distribution_plot(results_csv, output_dir)
+    print(f"Visualization finished! Elapsed: {time.time() - start_time:.2f}s")
+
+    return df_results
 
 
-def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
-    csv_path = Path(output_dir) / "evaluation_results.csv"
-    if not csv_path.exists():
-        return None
-    return pd.read_csv(csv_path)
-
-
-def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
-    if not results:
-        return
-    os.makedirs(output_dir, exist_ok=True)
-    summary_rows = []
-    for num_points, df in results.items():
-        summary_rows.append(
-            {
-                "num_points": num_points,
-                "iou_mean": df["iou"].mean(),
-                "iou_std": df["iou"].std(),
-                "dice_mean": df["dice"].mean(),
-                "dice_std": df["dice"].std(),
-                "f1_mean": df["f1_score"].mean(),
-                "f1_std": df["f1_score"].std(),
-                "precision_mean": df["precision"].mean(),
-                "recall_mean": df["recall"].mean(),
-            }
-        )
-    df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
-    summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
-    summary_path.parent.mkdir(parents=True, exist_ok=True)
-    df_summary.to_csv(summary_path, index=False)
-
+def compare_results(results_dict: dict, output_dir: str = "./results"):
+    """Compare results across point counts."""
+    import pandas as pd
     import matplotlib.pyplot as plt
-
-    metrics_to_plot = [
-        ("iou_mean", "iou_std", "IoU"),
-        ("dice_mean", "dice_std", "Dice"),
-        ("f1_mean", "f1_std", "F1-Score"),
-    ]
+
+    print("\n" + "=" * 80)
+    print("Comparing performance across point counts")
+    print("=" * 80)
+
+    # Collect all results
+    summary = []
+    for num_points, df in results_dict.items():
+        metrics = {
+            'num_points': num_points,
+            'iou_mean': df['iou'].mean(),
+            'iou_std': df['iou'].std(),
+            'dice_mean': df['dice'].mean(),
+            'dice_std': df['dice'].std(),
+            'f1_mean': df['f1_score'].mean(),
+            'f1_std': df['f1_score'].std(),
+            'precision_mean': df['precision'].mean(),
+            'recall_mean': df['recall'].mean(),
+        }
+        summary.append(metrics)
+
+    df_summary = pd.DataFrame(summary)
+
+    # Print the comparison table
+    print("\nPerformance comparison:")
+    print(df_summary.to_string(index=False))
+
+    # Save the comparison results
+    comparison_dir = os.path.join(output_dir, "point_comparison")
+    os.makedirs(comparison_dir, exist_ok=True)
+
+    csv_path = os.path.join(comparison_dir, "comparison_summary.csv")
+    df_summary.to_csv(csv_path, index=False)
+    print(f"\nComparison results saved to: {csv_path}")
+
+    # Plot the comparison figures
     fig, axes = plt.subplots(1, 3, figsize=(15, 5))
-    xs = df_summary["num_points"].tolist()
-    for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
-        ax.errorbar(
-            xs,
-            df_summary[mean_col],
-            yerr=df_summary[std_col],
-            marker="o",
-            capsize=5,
-            linewidth=2,
-            markersize=8,
-        )
-        ax.set_xlabel("Number of Points", fontsize=12)
+
+    metrics_to_plot = [
+        ('iou_mean', 'iou_std', 'IoU'),
+        ('dice_mean', 'dice_std', 'Dice'),
+        ('f1_mean', 'f1_std', 'F1-Score')
+    ]
+
+    for idx, (mean_col, std_col, title) in enumerate(metrics_to_plot):
+        ax = axes[idx]
+        x = df_summary['num_points']
+        y = df_summary[mean_col]
+        yerr = df_summary[std_col]
+
+        ax.errorbar(x, y, yerr=yerr, marker='o', capsize=5, linewidth=2, markersize=8)
+        ax.set_xlabel('Number of Points', fontsize=12)
         ax.set_ylabel(title, fontsize=12)
-        ax.set_title(f"{title} vs Number of Points", fontsize=14)
+        ax.set_title(f'{title} vs Number of Points', fontsize=14)
         ax.grid(True, alpha=0.3)
-        ax.set_xticks(xs)
+        ax.set_xticks(x)
+
     plt.tight_layout()
-    plot_path = summary_path.with_name("performance_comparison.png")
-    fig.savefig(plot_path, dpi=150, bbox_inches="tight")
+
+    plot_path = os.path.join(comparison_dir, "performance_comparison.png")
+    fig.savefig(plot_path, dpi=150, bbox_inches='tight')
     plt.close(fig)
+
+    print(f"Comparison plot saved to: {plot_path}")
+
+    return df_summary
 
 
-def main() -> None:
-    logging.basicConfig(level=logging.INFO)
-    args = parse_args()
-    if args.task_file:
-        task = load_task_from_toml(args.task_file)
-        TaskRunner(task).run()
-        return
-
-    comparison_data: Dict[int, pd.DataFrame] = {}
+def main():
+    """Main entry point."""
+    parser = argparse.ArgumentParser(
+        description="SAM2 point prompts - multi-point comparison experiments"
+    )
+
+    # Dataset arguments
+    parser.add_argument(
+        "--data_root", type=str, default="./crack500",
+        help="Dataset root directory"
+    )
+    parser.add_argument(
+        "--test_file", type=str, default="./crack500/test.txt",
+        help="Path to the test split file"
+    )
+
+    # Model arguments
+    parser.add_argument(
+        "--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
+        help="Path to the SAM2 model checkpoint"
+    )
+    parser.add_argument(
+        "--model_cfg", type=str, default="sam2.1_hiera_s.yaml",
+        help="SAM2 model config file"
+    )
+
+    # Experiment arguments
+    parser.add_argument(
+        "--point_configs", type=int, nargs='+', default=[1, 3, 5],
+        help="Point counts to test"
+    )
+    parser.add_argument(
+        "--per_component", action="store_true",
+        help="Sample points independently for each connected component"
+    )
+    parser.add_argument(
+        "--skip_comparison", action="store_true",
+        help="Skip the comparison analysis"
+    )
+
+    args = parser.parse_args()
+
+    print("=" * 80)
+    print("SAM2 point prompts - multi-point comparison experiments")
+    print("=" * 80)
+    print(f"Dataset root: {args.data_root}")
+    print(f"Test split file: {args.test_file}")
+    print(f"Model checkpoint: {args.checkpoint}")
+    print(f"Point configurations: {args.point_configs}")
+    print(f"Sampling strategy: {'per component' if args.per_component else 'global skeleton'}")
+    print("=" * 80)
+
+    # Check CUDA
+    import torch
+    if not torch.cuda.is_available():
+        print("Warning: CUDA is not available; running on CPU (this will be slow)")
+    else:
+        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
+
+    # Run all experiments
+    results_dict = {}
+
     for num_points in args.point_configs:
-        output_dir = default_output_dir(num_points, args.per_component)
-        task = build_task_for_points(args, num_points, output_dir)
-        if not task.steps:
+        try:
+            df_results = run_single_experiment(
+                data_root=args.data_root,
+                test_file=args.test_file,
+                checkpoint=args.checkpoint,
+                model_cfg=args.model_cfg,
+                num_points=num_points,
+                per_component=args.per_component
+            )
+            results_dict[num_points] = df_results
+
+        except Exception as e:
+            print(f"\nExperiment failed ({num_points} points): {str(e)}")
+            import traceback
+            traceback.print_exc()
             continue
-        TaskRunner(task).run()
-        if not args.skip_comparison and not args.skip_evaluation:
-            df = load_results_csv(output_dir)
-            if df is not None:
-                comparison_data[num_points] = df
-    if not args.skip_comparison and comparison_data:
-        compare_results(comparison_data, args.comparison_dir)
+
+    # Comparison analysis
+    if not args.skip_comparison and len(results_dict) > 1:
+        try:
+            compare_results(results_dict)
+        except Exception as e:
+            print(f"\nComparison analysis failed: {str(e)}")
+            import traceback
+            traceback.print_exc()
+
+    # Done
+    print("\n" + "=" * 80)
+    print("All experiments complete!")
+    print("=" * 80)
+    print("\nResult directories:")
+    for num_points in args.point_configs:
+        if args.per_component:
+            output_dir = f"./results/point_prompt_{num_points}pts_per_comp"
+        else:
+            output_dir = f"./results/point_prompt_{num_points}pts"
+        print(f"  - {num_points} points: {output_dir}")
+
+    if not args.skip_comparison and len(results_dict) > 1:
+        print(f"  - Comparison analysis: ./results/point_comparison")
+
+    print("=" * 80)
 
 
 if __name__ == "__main__":
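point_prompt.py itself is not shown in this patch; judging from the bbox variant below, the per-image call presumably reduces to the standard SAM2ImagePredictor point API. A sketch under that assumption (point coordinates hypothetical):

    import numpy as np

    predictor.set_image(image)                     # RGB array (H, W, 3)
    points = np.array([[120, 64], [200, 90]])      # (x, y) clicks on the crack
    labels = np.ones(len(points), dtype=np.int32)  # 1 = foreground point
    masks, scores, logits = predictor.predict(
        point_coords=points,
        point_labels=labels,
        box=None,
        multimask_output=False,
    )
    mask_pred = masks[0].astype(np.uint8) * 255    # binary mask, 0/255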
diff --git a/src/bbox_prompt.py b/src/bbox_prompt.py
index 1f4b6f7..0ff4660 100644
--- a/src/bbox_prompt.py
+++ b/src/bbox_prompt.py
@@ -1,39 +1,158 @@
 """
-SAM2 crack segmentation with bounding-box prompts (using HuggingFace Transformers)
+SAM2 crack segmentation with bounding-box prompts
 Extracts bounding boxes from the GT masks and segments with SAM2
 """
 import os
+import cv2
 import numpy as np
 import torch
 from pathlib import Path
-from typing import Dict, List
+from typing import List, Tuple, Dict
 from tqdm import tqdm
 import json
-import cv2
 
-from .dataset.utils import extract_bboxes_from_mask, load_image_and_mask
-from .hf_sam2_predictor import HFSam2Predictor
-from .model.inference import predict_with_bbox_prompt
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+
+def extract_bboxes_from_mask(mask: np.ndarray, expand_ratio: float = 0.0) -> List[np.ndarray]:
+    """
+    Extract a bounding box for every connected component in a binary mask
+
+    Args:
+        mask: binary mask (H, W) with values 0 or 255
+        expand_ratio: box expansion ratio, e.g. 0.05 expands by 5%
+
+    Returns:
+        List of bounding boxes in format [x1, y1, x2, y2]
+    """
+    # Make sure the mask is binary
+    binary_mask = (mask > 0).astype(np.uint8)
+
+    # Connected-component analysis
+    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
+        binary_mask, connectivity=8
+    )
+
+    bboxes = []
+    # Skip the background (label 0)
+    for i in range(1, num_labels):
+        x, y, w, h, area = stats[i]
+
+        # Filter out tiny regions (likely noise)
+        if area < 10:  # minimum area threshold
+            continue
+
+        # Compute the bounding box
+        x1, y1 = x, y
+        x2, y2 = x + w, y + h
+
+        # Expand the box if requested
+        if expand_ratio > 0:
+            cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
+            w_new = w * (1 + expand_ratio)
+            h_new = h * (1 + expand_ratio)
+            x1 = max(0, int(cx - w_new / 2))
+            y1 = max(0, int(cy - h_new / 2))
+            x2 = min(mask.shape[1], int(cx + w_new / 2))
+            y2 = min(mask.shape[0], int(cy + h_new / 2))
+
+        bboxes.append(np.array([x1, y1, x2, y2]))
+
+    return bboxes
+
+
+def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Load an image and its mask
+
+    Args:
+        image_path: image path
+        mask_path: mask path
+
+    Returns:
+        image: RGB image (H, W, 3)
+        mask: binary mask (H, W)
+    """
+    # Load the image
+    image = cv2.imread(image_path)
+    if image is None:
+        raise ValueError(f"Failed to load image: {image_path}")
+    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+    # Load the mask
+    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
+    if mask is None:
+        raise ValueError(f"Failed to load mask: {mask_path}")
+
+    return image, mask
+
+
+def predict_with_bbox_prompt(
+    predictor: SAM2ImagePredictor,
+    image: np.ndarray,
+    bboxes: List[np.ndarray]
+) -> np.ndarray:
+    """
+    Run SAM2 prediction with bounding-box prompts
+
+    Args:
+        predictor: SAM2ImagePredictor instance
+        image: RGB image (H, W, 3)
+        bboxes: list of boxes, each [x1, y1, x2, y2]
+
+    Returns:
+        combined_mask: merged prediction mask (H, W)
+    """
+    # Set the image
+    predictor.set_image(image)
+
+    # No boxes: return an empty mask
+    if len(bboxes) == 0:
+        return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
+
+    # Accumulate all predicted masks
+    combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
+
+    # Predict for each bounding box
+    for bbox in bboxes:
+        masks, scores, logits = predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=bbox[None, :],  # shape: (1, 4)
+            multimask_output=False,
+        )
+
+        # Take the first mask (multimask_output=False)
+        mask = masks[0]  # shape: (H, W)
+
+        # Merge into the combined mask
+        combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
+
+    # Convert to 0-255
+    combined_mask = combined_mask * 255
+
+    return combined_mask
 
 
 def process_test_set(
     data_root: str,
     test_file: str,
-    predictor: HFSam2Predictor,
+    predictor: SAM2ImagePredictor,
     output_dir: str,
     expand_ratio: float = 0.0
 ) -> List[Dict]:
     """
     Process the whole test set
-    
+
     Args:
         data_root: dataset root directory
         test_file: path to the test split file (test.txt)
-        predictor: HFSam2Predictor instance
+        predictor: SAM2ImagePredictor instance
         output_dir: output directory
         expand_ratio: bounding-box expansion ratio
-    
+
     Returns:
         results: list with per-sample information
     """
@@ -41,26 +160,26 @@
     os.makedirs(output_dir, exist_ok=True)
     pred_dir = os.path.join(output_dir, "predictions")
     os.makedirs(pred_dir, exist_ok=True)
-    
+
     # Read the test split file
     with open(test_file, 'r') as f:
         lines = f.readlines()
-    
+
     results = []
-    
+
     print(f"Processing {len(lines)} test images...")
-    
+
    for line in tqdm(lines, desc="Processing test set"):
         parts = line.strip().split()
         if len(parts) != 2:
             continue
-    
+
         img_rel_path, mask_rel_path = parts
-    
+
         # Build full paths
         img_path = os.path.join(data_root, img_rel_path)
         mask_path = os.path.join(data_root, mask_rel_path)
-    
+
         # Check that the files exist
         if not os.path.exists(img_path):
             print(f"Warning: image missing {img_path}")
@@ -68,23 +187,23 @@
             continue
         if not os.path.exists(mask_path):
             print(f"Warning: mask missing {mask_path}")
             continue
-    
+
         try:
             # Load the image and mask
             image, mask_gt = load_image_and_mask(img_path, mask_path)
-    
+
             # Extract bounding boxes from the GT mask
             bboxes = extract_bboxes_from_mask(mask_gt, expand_ratio=expand_ratio)
-    
+
             # Predict with SAM2
-            with torch.inference_mode():
+            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
                 mask_pred = predict_with_bbox_prompt(predictor, image, bboxes)
-    
+
             # Save the predicted mask
             img_name = Path(img_rel_path).stem
             pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
             cv2.imwrite(pred_path, mask_pred)
-    
+
             # Record the result
             results.append({
                 "image_path": img_rel_path,
@@ -93,23 +212,20 @@
                 "num_bboxes": len(bboxes),
                 "image_shape": image.shape[:2],
             })
-    
+
         except Exception as e:
             print(f"Failed to process {img_path}: {str(e)}")
-            # print stack trace
-            import traceback
-            traceback.print_exc()
             continue
-    
+
     # Save the run metadata
     results_file = os.path.join(output_dir, "results_info.json")
     with open(results_file, 'w') as f:
         json.dump(results, f, indent=2)
-    
+
     print(f"\nDone! Processed {len(results)} images")
     print(f"Predicted masks saved to: {pred_dir}")
     print(f"Run metadata saved to: {results_file}")
-    
+
     return results
 
 
@@ -118,36 +234,38 @@ def main():
     # Configuration
     DATA_ROOT = "./crack500"
     TEST_FILE = "./crack500/test.txt"
-    OUTPUT_DIR = "./results/bbox_prompt_hf"
-
-    # HuggingFace SAM2 model
-    MODEL_ID = "facebook/sam2-hiera-small"
-
+    OUTPUT_DIR = "./results/bbox_prompt"
+
+    # SAM2 model configuration
+    CHECKPOINT = "./sam2/checkpoints/sam2.1_hiera_small.pt"
+    MODEL_CFG = "sam2.1_hiera_s.yaml"
+
     # Bounding-box expansion ratio
     EXPAND_RATIO = 0.05  # 5% expansion
-    
+
     print("=" * 60)
-    print("SAM2 bounding-box prompts (HuggingFace) - Crack500 evaluation")
+    print("SAM2 bounding-box prompts - Crack500 evaluation")
     print("=" * 60)
     print(f"Dataset root: {DATA_ROOT}")
     print(f"Test split file: {TEST_FILE}")
-    print(f"Model: {MODEL_ID}")
+    print(f"Model checkpoint: {CHECKPOINT}")
+    print(f"Model config: {MODEL_CFG}")
     print(f"BBox expansion ratio: {EXPAND_RATIO * 100}%")
     print(f"Output directory: {OUTPUT_DIR}")
     print("=" * 60)
-    
+
     # Check CUDA availability
     if not torch.cuda.is_available():
         print("Warning: CUDA is not available; running on CPU (this will be slow)")
     else:
         print(f"Using GPU: {torch.cuda.get_device_name(0)}")
-
-    # Build the SAM2 predictor
+
+    # Build the SAM2 model
     print("\nLoading SAM2 model...")
-    from .hf_sam2_predictor import build_hf_sam2_predictor
-    predictor = build_hf_sam2_predictor(model_id=MODEL_ID)
+    sam2_model = build_sam2(MODEL_CFG, CHECKPOINT)
+    predictor = SAM2ImagePredictor(sam2_model)
    print("Model loaded!")
-    
+
     # Process the test set
     results = process_test_set(
         data_root=DATA_ROOT,
@@ -156,7 +274,7 @@
         output_dir=OUTPUT_DIR,
         expand_ratio=EXPAND_RATIO
     )
-    
+
     print("\n" + "=" * 60)
     print("Done! Next, run the evaluation script to compute metrics.")
     print("=" * 60)
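With the helpers now inlined above, one evaluation step is a three-call sequence. A sketch wired to the functions this file defines (sample paths hypothetical):

    from bbox_prompt import (extract_bboxes_from_mask, load_image_and_mask,
                             predict_with_bbox_prompt)

    image, mask_gt = load_image_and_mask("crack500/testcrop/sample.jpg",
                                         "crack500/testdata/sample.png")
    bboxes = extract_bboxes_from_mask(mask_gt, expand_ratio=0.05)   # one box per component
    mask_pred = predict_with_bbox_prompt(predictor, image, bboxes)  # uint8 mask, 0/255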
- """ - - dataset_name: str = "base" - - def __init__( - self, - config: DatasetConfig, - transforms: Optional[Callable[[ModelReadySample], ModelReadySample]] = None, - return_hf_dict: bool = True, - ) -> None: - self.config = config - self.transforms = transforms - self.return_hf_dict = return_hf_dict - self.records: List[DatasetRecord] = self.load_records() - - def __len__(self) -> int: - return len(self.records) - - def __getitem__(self, index: int) -> Dict[str, Any] | ModelReadySample: - record = self.records[index] - sample = self.prepare_sample(record) - if self.transforms: - sample = self.transforms(sample) - return sample.to_hf_dict() if self.return_hf_dict else sample - - @abc.abstractmethod - def load_records(self) -> List[DatasetRecord]: - """ - Scan the dataset directory / annotation files and return - structured references to each item on disk. - """ - - def prepare_sample(self, record: DatasetRecord) -> ModelReadySample: - """ - Load image/mask/prompt data from disk and wrap it inside ModelReadySample. - Subclasses can override this to implement custom augmentations or prompt generation. - """ - image = self._load_image(record.image_path) - mask = ( - self._load_mask(record.mask_path) - if record.mask_path is not None - else None - ) - prompts = self.build_prompts(record, mask) - labels = {"mask": mask} if mask is not None else {} - sample = ModelReadySample( - pixel_values=image, - prompts=prompts, - labels=labels, - metadata=record.metadata, - ) - return sample - - def build_prompts( - self, record: DatasetRecord, mask: Optional[np.ndarray] - ) -> Dict[str, Any]: - """ - Derive prompts from metadata or masks. - Default implementation extracts bounding boxes from masks. - """ - if mask is None: - return {} - boxes = self._mask_to_bboxes(mask) - return {"boxes": boxes} - - def _load_image(self, path: Path) -> np.ndarray: - image = Image.open(path).convert("RGB") - return np.array(image) - - def _load_mask(self, path: Optional[Path]) -> Optional[np.ndarray]: - if path is None: - return None - mask = Image.open(path).convert("L") - return np.array(mask) - - def _mask_to_bboxes(self, mask: np.ndarray) -> List[List[int]]: - """ - Helper that mirrors the legacy bbox extraction pipeline. - """ - if mask.ndim != 2: - raise ValueError("Mask must be 2-dimensional.") - ys, xs = np.where(mask > 0) - if ys.size == 0: - return [] - x_min, x_max = xs.min(), xs.max() - y_min, y_max = ys.min(), ys.max() - return [[int(x_min), int(y_min), int(x_max), int(y_max)]] - - -def collate_samples(batch: Iterable[Dict[str, Any] | ModelReadySample]) -> Dict[str, Any]: - """ - Default collate_fn that merges ModelReadySample/HF dict outputs. 
- """ - pixel_values = [] - prompts: List[Dict[str, Any]] = [] - labels: List[Dict[str, Any]] = [] - metadata: List[Dict[str, Any]] = [] - for item in batch: - if isinstance(item, ModelReadySample): - payload = item.to_hf_dict() - else: - payload = item - pixel_values.append(payload["pixel_values"]) - prompts.append(payload.get("prompts", {})) - labels.append(payload.get("labels", {})) - metadata.append(payload.get("metadata", {})) - stacked = { - "pixel_values": torch.as_tensor(np.stack(pixel_values)), - "prompts": prompts, - "labels": labels, - "metadata": metadata, - } - return stacked diff --git a/src/dataset/crack500.py b/src/dataset/crack500.py deleted file mode 100644 index cd9a749..0000000 --- a/src/dataset/crack500.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import numpy as np - -from .base import BaseDataset, DatasetRecord -from .registry import DatasetRegistry -from .utils import ( - extract_bboxes_from_mask, - sample_points_on_skeleton, - sample_points_per_component, -) -from ..model_configuration.config import DatasetConfig - - -@DatasetRegistry.register("crack500") -class Crack500Dataset(BaseDataset): - """ - Reference implementation that loads Crack500 samples from an image list. - """ - - def __init__( - self, - config: DatasetConfig, - expand_ratio: float = 0.05, - min_area: int = 10, - **kwargs, - ) -> None: - extra = dict(config.extra_params or {}) - expand_ratio = float(extra.get("expand_ratio", expand_ratio)) - self.prompt_mode = extra.get("prompt_mode", "bbox") - self.num_points = int(extra.get("num_points", 5)) - self.per_component = bool(extra.get("per_component", False)) - self.expand_ratio = expand_ratio - self.min_area = min_area - super().__init__(config, **kwargs) - - def load_records(self) -> List[DatasetRecord]: - base_dir = Path(self.config.data_root) - list_file = ( - Path(self.config.annotation_file) - if self.config.annotation_file - else base_dir / (self.config.split_file or "test.txt") - ) - if not list_file.exists(): - raise FileNotFoundError(f"Missing Crack500 split file: {list_file}") - image_dir = base_dir / (self.config.image_folder or "testcrop") - mask_dir = base_dir / (self.config.mask_folder or "testdata") - records: List[DatasetRecord] = [] - with list_file.open("r", encoding="utf-8") as handle: - for line in handle: - image_name = line.strip() - if not image_name: - continue - image_path = image_dir / image_name - mask_name = image_name.replace(".jpg", ".png") - mask_path = mask_dir / mask_name - metadata = {"split": self.config.split, "image_name": image_name} - records.append( - DatasetRecord( - image_path=image_path, - mask_path=mask_path if mask_path.exists() else None, - metadata=metadata, - ) - ) - if not records: - raise RuntimeError( - f"No records found in {image_dir} for split {self.config.split}" - ) - return records - - def build_prompts( - self, - record: DatasetRecord, - mask: Optional[np.ndarray], - ) -> Dict[str, List[List[int]]]: - if mask is None: - return {} - if self.prompt_mode == "point": - points, point_labels = self._build_point_prompts(mask) - if points.size == 0: - return {} - prompts: Dict[str, List[List[int]]] = {"points": points.tolist()} - if point_labels.size > 0: - prompts["point_labels"] = point_labels.tolist() - return prompts - boxes = extract_bboxes_from_mask( - mask, expand_ratio=self.expand_ratio, min_area=self.min_area - ) - return {"boxes": boxes} - - def _build_point_prompts(self, mask: np.ndarray) 
-> Tuple[np.ndarray, np.ndarray]: - if self.per_component: - return sample_points_per_component(mask, self.num_points) - points = sample_points_on_skeleton(mask, self.num_points) - labels = np.ones(points.shape[0], dtype=np.int32) - return points, labels diff --git a/src/dataset/registry.py b/src/dataset/registry.py deleted file mode 100644 index 558d416..0000000 --- a/src/dataset/registry.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from typing import Dict, Type - -from .base import BaseDataset - - -class DatasetRegistry: - """ - Simple registry so configs can refer to datasets by string key. - """ - - _registry: Dict[str, Type[BaseDataset]] = {} - - @classmethod - def register(cls, name: str): - def decorator(dataset_cls: Type[BaseDataset]) -> Type[BaseDataset]: - cls._registry[name] = dataset_cls - dataset_cls.dataset_name = name - return dataset_cls - - return decorator - - @classmethod - def create(cls, name: str, *args, **kwargs) -> BaseDataset: - if name not in cls._registry: - raise KeyError(f"Dataset '{name}' is not registered.") - dataset_cls = cls._registry[name] - return dataset_cls(*args, **kwargs) - - @classmethod - def available(cls) -> Dict[str, Type[BaseDataset]]: - return dict(cls._registry) diff --git a/src/dataset/utils.py b/src/dataset/utils.py deleted file mode 100644 index 184ac43..0000000 --- a/src/dataset/utils.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import List, Tuple - -import cv2 -import numpy as np -from skimage.morphology import skeletonize - - -def load_image_and_mask(image_path: str | Path, mask_path: str | Path) -> Tuple[np.ndarray, np.ndarray]: - """ - Reads an RGB image and its mask counterpart. - """ - image_path = str(image_path) - mask_path = str(mask_path) - image = cv2.imread(image_path) - if image is None: - raise ValueError(f"无法加载图像: {image_path}") - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE) - if mask is None: - raise ValueError(f"无法加载掩码: {mask_path}") - return image, mask - - -def extract_bboxes_from_mask( - mask: np.ndarray, - expand_ratio: float = 0.0, - min_area: int = 10, -) -> List[List[int]]: - """ - Extract bounding boxes from a binary mask using connected components. - """ - binary_mask = (mask > 0).astype(np.uint8) - num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8) - bboxes: List[List[int]] = [] - for i in range(1, num_labels): - x, y, w, h, area = stats[i] - if area < min_area: - continue - x1, y1 = x, y - x2, y2 = x + w, y + h - if expand_ratio > 0: - cx, cy = (x1 + x2) / 2, (y1 + y2) / 2 - w_new = w * (1 + expand_ratio) - h_new = h * (1 + expand_ratio) - x1 = max(0, int(cx - w_new / 2)) - y1 = max(0, int(cy - h_new / 2)) - x2 = min(mask.shape[1], int(cx + w_new / 2)) - y2 = min(mask.shape[0], int(cy + h_new / 2)) - bboxes.append([x1, y1, x2, y2]) - return bboxes - - -def sample_points_on_skeleton(mask: np.ndarray, num_points: int) -> np.ndarray: - """ - Sample points uniformly along the mask skeleton in (x, y) order. 
- """ - binary_mask = (mask > 0).astype(bool) - try: - skeleton = skeletonize(binary_mask) - except Exception: - skeleton = binary_mask - coords = np.argwhere(skeleton) - if coords.size == 0: - return np.zeros((0, 2), dtype=np.int32) - if coords.shape[0] <= num_points: - points = coords[:, [1, 0]] - return points.astype(np.int32) - indices = np.linspace(0, coords.shape[0] - 1, num_points, dtype=int) - sampled = coords[indices][:, [1, 0]] - return sampled.astype(np.int32) - - -def sample_points_per_component(mask: np.ndarray, num_points_per_component: int) -> Tuple[np.ndarray, np.ndarray]: - """ - Sample points per connected component along each component's skeleton. - """ - num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8)) - all_points = [] - for region_id in range(1, num_labels): - region_mask = (labels_map == region_id).astype(np.uint8) * 255 - points = sample_points_on_skeleton(region_mask, num_points_per_component) - if len(points): - all_points.append(points) - if not all_points: - return np.zeros((0, 2), dtype=np.int32), np.zeros(0, dtype=np.int32) - stacked = np.vstack(all_points) - labels = np.ones(stacked.shape[0], dtype=np.int32) - return stacked, labels diff --git a/src/legacy_evaluation.py b/src/evaluation.py similarity index 100% rename from src/legacy_evaluation.py rename to src/evaluation.py diff --git a/src/evaluation/__init__.py b/src/evaluation/__init__.py deleted file mode 100644 index 3b51e5c..0000000 --- a/src/evaluation/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from .metrics import METRIC_REGISTRY, compute_dice, compute_iou, compute_precision, compute_recall -from .pipeline_eval import PipelineEvaluator -from .reporting import write_csv, write_json - -__all__ = [ - "METRIC_REGISTRY", - "PipelineEvaluator", - "compute_dice", - "compute_iou", - "compute_precision", - "compute_recall", - "write_csv", - "write_json", -] diff --git a/src/evaluation/metrics.py b/src/evaluation/metrics.py deleted file mode 100644 index d015346..0000000 --- a/src/evaluation/metrics.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -from typing import Callable, Dict, Iterable, Tuple - -import numpy as np - - -def compute_iou(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float: - pred_bin = (pred >= threshold).astype(np.uint8) - target_bin = (target > 0).astype(np.uint8) - intersection = (pred_bin & target_bin).sum() - union = (pred_bin | target_bin).sum() - return float(intersection / union) if union else 0.0 - - -def compute_dice(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float: - pred_bin = (pred >= threshold).astype(np.uint8) - target_bin = (target > 0).astype(np.uint8) - intersection = (pred_bin & target_bin).sum() - total = pred_bin.sum() + target_bin.sum() - return float((2 * intersection) / total) if total else 0.0 - - -def compute_precision(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float: - pred_bin = (pred >= threshold).astype(np.uint8) - target_bin = (target > 0).astype(np.uint8) - tp = (pred_bin & target_bin).sum() - fp = (pred_bin & (1 - target_bin)).sum() - return float(tp / (tp + fp)) if (tp + fp) else 0.0 - - -def compute_recall(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float: - pred_bin = (pred >= threshold).astype(np.uint8) - target_bin = (target > 0).astype(np.uint8) - tp = (pred_bin & target_bin).sum() - fn = ((1 - pred_bin) & target_bin).sum() - return float(tp / (tp + fn)) if (tp + fn) else 0.0 - - -MetricFn = Callable[[np.ndarray, 
diff --git a/src/evaluation/pipeline_eval.py b/src/evaluation/pipeline_eval.py
deleted file mode 100644
index 8a06d78..0000000
--- a/src/evaluation/pipeline_eval.py
+++ /dev/null
@@ -1,95 +0,0 @@
-from __future__ import annotations
-
-import json
-from pathlib import Path
-from typing import Any, Dict, List, Optional
-
-import numpy as np
-from tqdm import tqdm
-
-from ..dataset import BaseDataset
-from ..model import BaseModelAdapter
-from ..model_configuration import EvaluationConfig
-from .metrics import resolve_metrics
-from .utils import extract_mask_from_pipeline_output
-
-
-class PipelineEvaluator:
-    """
-    Runs a Hugging Face pipeline across a dataset and aggregates metrics.
-    """
-
-    def __init__(
-        self,
-        dataset: BaseDataset,
-        adapter: BaseModelAdapter,
-        config: EvaluationConfig,
-    ) -> None:
-        self.dataset = dataset
-        self.adapter = adapter
-        self.config = config
-        self.metrics = resolve_metrics(config.metrics)
-
-    def run(self) -> Dict[str, Any]:
-        pipe = self.adapter.build_pipeline()
-        aggregated: Dict[str, List[float]] = {name: [] for name in self.metrics}
-        output_dir = Path(self.config.output_dir)
-        output_dir.mkdir(parents=True, exist_ok=True)
-        requested = self.config.max_samples or len(self.dataset)
-        total = min(requested, len(self.dataset))
-        prog_bar = tqdm(range(total), total=total)
-        for idx in prog_bar:
-            sample = self.dataset[idx]
-            inputs = self._build_pipeline_inputs(sample)
-            preds = pipe(**inputs)
-            labels = sample.get("labels", {})
-            mask = labels.get("mask")
-            if mask is None:
-                continue
-            prediction_mask = self._extract_mask(preds)
-            for metric_name, metric_fn in self.metrics.items():
-                for threshold in self.config.thresholds:
-                    value = metric_fn(prediction_mask, mask, threshold)
-                    aggregated.setdefault(f"{metric_name}@{threshold}", []).append(value)
-            if self.config.save_predictions:
-                self._write_prediction(output_dir, idx, prediction_mask, sample["metadata"])
-        summary = {
-            "metrics": {k: float(np.mean(v)) if v else 0.0 for k, v in aggregated.items()},
-            "config": self.config.__dict__,
-            "num_samples": total,
-        }
-        with (output_dir / "evaluation_summary.json").open("w", encoding="utf-8") as handle:
-            json.dump(summary, handle, indent=2)
-        return summary
-
-    def _build_pipeline_inputs(self, sample: Dict[str, Any]) -> Dict[str, Any]:
-        inputs: Dict[str, Any] = {"images": sample["pixel_values"]}
-        prompts = sample.get("prompts", {})
-        if "boxes" in prompts and prompts["boxes"]:
-            inputs["boxes"] = prompts["boxes"]
-        if "points" in prompts and prompts["points"]:
-            inputs["points"] = prompts["points"]
-        if "point_labels" in prompts and prompts["point_labels"]:
-            inputs["point_labels"] = prompts["point_labels"]
-        return inputs
-
-    def _extract_mask(self, pipeline_output: Any) -> np.ndarray:
-        """
-        Normalize pipeline outputs into numpy masks.
- """ - return extract_mask_from_pipeline_output(pipeline_output) - - def _write_prediction( - self, - output_dir: Path, - index: int, - mask: np.ndarray, - metadata: Optional[Dict[str, Any]], - ) -> None: - if metadata and "image_name" in metadata: - filename = metadata["image_name"].replace(".jpg", "_pred.npy") - else: - filename = f"sample_{index:04d}_pred.npy" - target_path = output_dir / "predictions" - target_path.mkdir(parents=True, exist_ok=True) - np.save(target_path / filename, mask) diff --git a/src/evaluation/reporting.py b/src/evaluation/reporting.py deleted file mode 100644 index 38d8daf..0000000 --- a/src/evaluation/reporting.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -import csv -import json -from pathlib import Path -from typing import Dict, Iterable - - -def write_json(summary: Dict, output_path: Path) -> None: - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("w", encoding="utf-8") as handle: - json.dump(summary, handle, indent=2) - - -def write_csv(rows: Iterable[Dict], output_path: Path) -> None: - rows = list(rows) - if not rows: - return - output_path.parent.mkdir(parents=True, exist_ok=True) - fieldnames = sorted(rows[0].keys()) - with output_path.open("w", encoding="utf-8", newline="") as handle: - writer = csv.DictWriter(handle, fieldnames=fieldnames) - writer.writeheader() - for row in rows: - writer.writerow(row) diff --git a/src/evaluation/run_pipeline.py b/src/evaluation/run_pipeline.py deleted file mode 100644 index 8eaa9bf..0000000 --- a/src/evaluation/run_pipeline.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import annotations - -import logging -from dataclasses import dataclass, replace -from typing import Optional - -from transformers import HfArgumentParser - -from ..dataset import DatasetRegistry -from ..model import ModelRegistry -from ..model_configuration import ConfigRegistry, EvaluationConfig -from .pipeline_eval import PipelineEvaluator - -LOGGER = logging.getLogger(__name__) - - -@dataclass -class PipelineCLIArguments: - config_name: str = "sam2_bbox_prompt" - model_key: str = "sam2" - split: str = "test" - split_file: Optional[str] = None - device: Optional[str] = None - max_samples: Optional[int] = None - - -def main() -> None: - parser = HfArgumentParser(PipelineCLIArguments) - (cli_args,) = parser.parse_args_into_dataclasses() - project_config = ConfigRegistry.get(cli_args.config_name) - dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file) - dataset = DatasetRegistry.create( - dataset_cfg.name, - config=dataset_cfg, - return_hf_dict=True, - ) - adapter = ModelRegistry.create(cli_args.model_key, project_config.model) - evaluation_config = replace( - project_config.evaluation, - max_samples=cli_args.max_samples, - ) - if cli_args.device: - adapter.build_pipeline(device=cli_args.device) - evaluator = PipelineEvaluator( - dataset=dataset, - adapter=adapter, - config=evaluation_config, - ) - summary = evaluator.run() - LOGGER.info("Evaluation summary: %s", summary) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - main() diff --git a/src/evaluation/utils.py b/src/evaluation/utils.py deleted file mode 100644 index 8aac66e..0000000 --- a/src/evaluation/utils.py +++ /dev/null @@ -1,16 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import numpy as np - - -def extract_mask_from_pipeline_output(pipeline_output: Any) -> np.ndarray: - if isinstance(pipeline_output, list): - pipeline_output = 
pipeline_output[0] - mask = pipeline_output.get("mask") - if mask is None: - raise ValueError("Pipeline output missing 'mask'.") - if isinstance(mask, np.ndarray): - return mask - return np.array(mask) diff --git a/src/hf_sam2_predictor.py b/src/hf_sam2_predictor.py deleted file mode 100644 index d91f387..0000000 --- a/src/hf_sam2_predictor.py +++ /dev/null @@ -1,7 +0,0 @@ -""" -Backward-compatible wrapper that re-exports the predictor relocated to src.model. -""" - -from .model.predictor import HFSam2Predictor, build_hf_sam2_predictor - -__all__ = ["HFSam2Predictor", "build_hf_sam2_predictor"] diff --git a/src/model/__init__.py b/src/model/__init__.py deleted file mode 100644 index dfed61c..0000000 --- a/src/model/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .base import BaseModelAdapter -from .inference import predict_with_bbox_prompt -from .predictor import HFSam2Predictor, build_hf_sam2_predictor -from .registry import ModelRegistry -from .sam2_adapter import Sam2ModelAdapter -from .trainer import FineTuningTrainer, TrainerArtifacts - -__all__ = [ - "BaseModelAdapter", - "FineTuningTrainer", - "HFSam2Predictor", - "ModelRegistry", - "Sam2ModelAdapter", - "TrainerArtifacts", - "build_hf_sam2_predictor", - "predict_with_bbox_prompt", -] diff --git a/src/model/base.py b/src/model/base.py deleted file mode 100644 index ae6e362..0000000 --- a/src/model/base.py +++ /dev/null @@ -1,66 +0,0 @@ -from __future__ import annotations - -import abc -from typing import Any, Dict, Optional - -from transformers import pipeline - -from ..model_configuration import ModelConfig - - -class BaseModelAdapter(abc.ABC): - """ - Thin wrapper that standardizes how we instantiate models/processors/pipelines. - """ - - task: str = "image-segmentation" - - def __init__(self, config: ModelConfig) -> None: - self.config = config - self._model = None - self._processor = None - self._pipeline = None - - def load_pretrained(self): - if self._model is None or self._processor is None: - self._model, self._processor = self._load_pretrained() - return self._model, self._processor - - def build_pipeline( - self, - device: Optional[str] = None, - **kwargs, - ): - if self._pipeline is None: - model, processor = self.load_pretrained() - pipe_kwargs = { - "task": self.task, - "model": model, - "image_processor": processor, - **self.config.pipeline_kwargs, - **kwargs, - } - if device is not None: - pipe_kwargs["device"] = device - self._pipeline = self._create_pipeline(pipe_kwargs) - return self._pipeline - - async def build_pipeline_async(self, **kwargs): - """ - Async helper for future multi-device orchestration. - """ - return self.build_pipeline(**kwargs) - - def save_pretrained(self, output_dir: str) -> None: - model, processor = self.load_pretrained() - model.save_pretrained(output_dir) - processor.save_pretrained(output_dir) - - @abc.abstractmethod - def _load_pretrained(self): - """ - Return (model, processor) tuple. - """ - - def _create_pipeline(self, pipe_kwargs: Dict[str, Any]): - return pipeline(**pipe_kwargs) diff --git a/src/model/inference.py b/src/model/inference.py deleted file mode 100644 index 0548f38..0000000 --- a/src/model/inference.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -from typing import List - -import numpy as np - -from .predictor import HFSam2Predictor - - -def predict_with_bbox_prompt( - predictor: HFSam2Predictor, - image: np.ndarray, - bboxes: List[np.ndarray], -) -> np.ndarray: - """ - Run SAM2 predictions for each bounding box and merge the masks. 
- """ - predictor.set_image(image) - if not bboxes: - return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) - combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8) - for bbox in bboxes: - masks, _, _ = predictor.predict( - point_coords=None, - point_labels=None, - box=bbox, - multimask_output=False, - ) - mask = masks[0] - combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8) - combined_mask = combined_mask * 255 - return combined_mask diff --git a/src/model/predictor.py b/src/model/predictor.py deleted file mode 100644 index ae2befc..0000000 --- a/src/model/predictor.py +++ /dev/null @@ -1,158 +0,0 @@ -from __future__ import annotations - -import json -from pathlib import Path -from typing import Optional, Tuple - -import numpy as np -import torch -from PIL import Image -from transformers import SamModel, SamProcessor - - -class HFSam2Predictor: - """ - Predictor wrapper around Hugging Face SAM2 models. - """ - - def __init__( - self, - model_id: str = "facebook/sam2-hiera-small", - device: Optional[str] = None, - dtype: torch.dtype = torch.bfloat16, - ) -> None: - self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") - self.dtype = dtype - self.model = SamModel.from_pretrained(model_id).to(self.device) - self.processor = SamProcessor.from_pretrained("./configs/preprocesser.json") - self._override_processor_config() - if dtype == torch.bfloat16: - self.model = self.model.to(dtype=dtype) - self.model.eval() - self.current_image = None - self.current_image_embeddings = None - - def set_image(self, image: np.ndarray) -> None: - if isinstance(image, np.ndarray): - pil_image = Image.fromarray(image.astype(np.uint8)) - else: - pil_image = image - self.current_image = pil_image - with torch.inference_mode(): - inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device) - if self.dtype == torch.bfloat16: - inputs = { - k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v - for k, v in inputs.items() - } - self.current_image_embeddings = self.model.get_image_embeddings(inputs["pixel_values"]) - - def predict( - self, - point_coords: Optional[np.ndarray] = None, - point_labels: Optional[np.ndarray] = None, - box: Optional[np.ndarray] = None, - multimask_output: bool = False, - ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - if self.current_image is None: - raise ValueError("No image set. 
Call set_image() first.") - input_points = self._prepare_points(point_coords) - input_labels = self._prepare_labels(point_labels) - input_boxes = self._prepare_boxes(box) - with torch.inference_mode(): - inputs = self.processor( - images=self.current_image, - input_points=input_points, - input_labels=input_labels, - input_boxes=input_boxes, - return_tensors="pt", - ).to(self.device) - if self.dtype == torch.bfloat16: - inputs = { - k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v - for k, v in inputs.items() - } - inputs.pop("pixel_values", None) - inputs["image_embeddings"] = self.current_image_embeddings - outputs = self.model(**inputs, multimask_output=multimask_output) - masks = self.processor.image_processor.post_process_masks( - outputs.pred_masks.float().cpu(), - inputs["original_sizes"].cpu(), - inputs["reshaped_input_sizes"].cpu(), - )[0] - scores = outputs.iou_scores.float().cpu().numpy()[0] - masks_np = (masks.squeeze(1).numpy() > 0).astype(np.uint8) - logits = outputs.pred_masks.float().cpu().numpy()[0] - return masks_np, scores, logits - - def _prepare_points(self, coords: Optional[np.ndarray]): - """ - Points must be shaped (num_points, 2); wrap in outer batch dimension. - """ - if coords is None: - return None - coords_arr = np.asarray(coords) - if coords_arr.ndim == 1: - coords_arr = coords_arr[None, :] - if coords_arr.ndim != 2: - raise ValueError(f"Point coords must be 2-D, got {coords_arr.shape}.") - return [coords_arr.tolist()] - - def _prepare_labels(self, labels: Optional[np.ndarray]): - """ - Labels mirror the point dimension and are shaped (num_points,). - """ - if labels is None: - return None - labels_arr = np.asarray(labels) - if labels_arr.ndim == 0: - labels_arr = labels_arr[None] - if labels_arr.ndim != 1: - raise ValueError(f"Point labels must be 1-D, got {labels_arr.shape}.") - return [labels_arr.tolist()] - - def _prepare_boxes(self, boxes: Optional[np.ndarray]): - """ - HF expects boxes in shape (batch, num_boxes, 4); accept (4,), (N,4), or (B,N,4). - """ - if boxes is None: - return None - boxes_arr = np.asarray(boxes) - if boxes_arr.ndim == 1: - return [[boxes_arr.tolist()]] - if boxes_arr.ndim == 2: - return [boxes_arr.tolist()] - if boxes_arr.ndim == 3: - return boxes_arr.tolist() - raise ValueError(f"Boxes should be 1/2/3-D, got {boxes_arr.shape}.") - - def _override_processor_config(self) -> None: - """ - Override processor config with local settings to avoid upstream regressions. - """ - config_path = Path(__file__).resolve().parents[2] / "configs" / "preprocesser.json" - if not config_path.exists(): - return - try: - config_dict = json.loads(config_path.read_text()) - except Exception: - return - image_processor = getattr(self.processor, "image_processor", None) - if image_processor is None or not hasattr(image_processor, "config"): - return - # config behaves like a dict; update in-place. 
- try: - image_processor.config.update(config_dict) - except Exception: - for key, value in config_dict.items(): - try: - setattr(image_processor.config, key, value) - except Exception: - continue - - -def build_hf_sam2_predictor( - model_id: str = "facebook/sam2-hiera-small", - device: Optional[str] = None, -) -> HFSam2Predictor: - return HFSam2Predictor(model_id=model_id, device=device) diff --git a/src/model/registry.py b/src/model/registry.py deleted file mode 100644 index f2139fe..0000000 --- a/src/model/registry.py +++ /dev/null @@ -1,33 +0,0 @@ -from __future__ import annotations - -from typing import Dict, Type - -from ..model_configuration import ModelConfig -from .base import BaseModelAdapter - - -class ModelRegistry: - """ - Maps model keys to adapter classes so configs can reference them declaratively. - """ - - _registry: Dict[str, Type[BaseModelAdapter]] = {} - - @classmethod - def register(cls, name: str): - def decorator(adapter_cls: Type[BaseModelAdapter]) -> Type[BaseModelAdapter]: - cls._registry[name] = adapter_cls - return adapter_cls - - return decorator - - @classmethod - def create(cls, name: str, config: ModelConfig) -> BaseModelAdapter: - if name not in cls._registry: - raise KeyError(f"ModelAdapter '{name}' is not registered.") - adapter_cls = cls._registry[name] - return adapter_cls(config) - - @classmethod - def available(cls) -> Dict[str, Type[BaseModelAdapter]]: - return dict(cls._registry) diff --git a/src/model/sam2_adapter.py b/src/model/sam2_adapter.py deleted file mode 100644 index c2be204..0000000 --- a/src/model/sam2_adapter.py +++ /dev/null @@ -1,35 +0,0 @@ -from __future__ import annotations - -from typing import Any, Tuple - -from transformers import AutoModelForImageSegmentation, AutoProcessor - -from ..model_configuration import ModelConfig -from .base import BaseModelAdapter -from .registry import ModelRegistry - - -@ModelRegistry.register("sam2") -class Sam2ModelAdapter(BaseModelAdapter): - """ - Adapter that exposes SAM2 checkpoints through the HF pipeline interface. 
- """ - - def __init__(self, config: ModelConfig) -> None: - super().__init__(config) - self.task = "image-segmentation" - - def _load_pretrained(self) -> Tuple[Any, Any]: - model = AutoModelForImageSegmentation.from_pretrained( - self.config.name_or_path, - revision=self.config.revision, - cache_dir=self.config.cache_dir, - trust_remote_code=True, - ) - processor = AutoProcessor.from_pretrained( - self.config.name_or_path, - revision=self.config.revision, - cache_dir=self.config.cache_dir, - trust_remote_code=True, - ) - return model, processor diff --git a/src/model/train_hf.py b/src/model/train_hf.py deleted file mode 100644 index 22654f6..0000000 --- a/src/model/train_hf.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -import logging -from dataclasses import dataclass, replace -from typing import Optional - -from transformers import HfArgumentParser - -from ..dataset import DatasetRegistry -from ..model_configuration import ConfigRegistry, DatasetConfig -from .registry import ModelRegistry -from .trainer import FineTuningTrainer - -LOGGER = logging.getLogger(__name__) - - -@dataclass -class TrainCLIArguments: - config_name: str = "sam2_bbox_prompt" - model_key: str = "sam2" - train_split: str = "train" - eval_split: str = "val" - train_split_file: Optional[str] = None - eval_split_file: Optional[str] = None - skip_eval: bool = False - device: Optional[str] = None - - -def build_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig: - overrides = {} - if split: - overrides["split"] = split - if split_file: - overrides["split_file"] = split_file - return replace(config, **overrides) - - -def main() -> None: - parser = HfArgumentParser(TrainCLIArguments) - (cli_args,) = parser.parse_args_into_dataclasses() - project_config = ConfigRegistry.get(cli_args.config_name) - train_dataset_cfg = build_dataset( - project_config.dataset, cli_args.train_split, cli_args.train_split_file - ) - eval_dataset_cfg = ( - build_dataset(project_config.dataset, cli_args.eval_split, cli_args.eval_split_file) - if not cli_args.skip_eval - else None - ) - - train_dataset = DatasetRegistry.create( - train_dataset_cfg.name, - config=train_dataset_cfg, - return_hf_dict=True, - ) - eval_dataset = ( - DatasetRegistry.create( - eval_dataset_cfg.name, - config=eval_dataset_cfg, - return_hf_dict=True, - ) - if eval_dataset_cfg - else None - ) - - adapter = ModelRegistry.create(cli_args.model_key, project_config.model) - if cli_args.device: - adapter.build_pipeline(device=cli_args.device) - - trainer_builder = FineTuningTrainer( - adapter=adapter, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - training_config=project_config.training, - ) - artifacts = trainer_builder.build() - LOGGER.info("Starting training with args: %s", artifacts.training_args) - train_result = artifacts.trainer.train() - LOGGER.info("Training finished: %s", train_result) - artifacts.trainer.save_model(project_config.training.output_dir) - if eval_dataset and not cli_args.skip_eval: - metrics = artifacts.trainer.evaluate() - LOGGER.info("Evaluation metrics: %s", metrics) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - main() diff --git a/src/model/trainer.py b/src/model/trainer.py deleted file mode 100644 index 6973a02..0000000 --- a/src/model/trainer.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Any, Dict, Optional - -from transformers import Trainer, TrainingArguments - 
-from ..dataset import BaseDataset, collate_samples -from ..model_configuration import TrainingConfig -from .base import BaseModelAdapter - - -@dataclass -class TrainerArtifacts: - trainer: Trainer - training_args: TrainingArguments - - -class FineTuningTrainer: - """ - Helper that bridges TrainingConfig + datasets + adapters into HF Trainer. - """ - - def __init__( - self, - adapter: BaseModelAdapter, - train_dataset: Optional[BaseDataset], - eval_dataset: Optional[BaseDataset], - training_config: TrainingConfig, - trainer_kwargs: Optional[Dict[str, Any]] = None, - ) -> None: - self.adapter = adapter - self.train_dataset = train_dataset - self.eval_dataset = eval_dataset - self.training_config = training_config - self.trainer_kwargs = trainer_kwargs or {} - - def build(self) -> TrainerArtifacts: - model, processor = self.adapter.load_pretrained() - training_args = TrainingArguments( - output_dir=self.training_config.output_dir, - num_train_epochs=self.training_config.num_train_epochs, - per_device_train_batch_size=self.training_config.per_device_train_batch_size, - per_device_eval_batch_size=self.training_config.per_device_eval_batch_size, - learning_rate=self.training_config.learning_rate, - gradient_accumulation_steps=self.training_config.gradient_accumulation_steps, - lr_scheduler_type=self.training_config.lr_scheduler_type, - warmup_ratio=self.training_config.warmup_ratio, - weight_decay=self.training_config.weight_decay, - seed=self.training_config.seed, - fp16=self.training_config.fp16, - bf16=self.training_config.bf16, - report_to=self.training_config.report_to, - ) - hf_trainer = Trainer( - model=model, - args=training_args, - train_dataset=self.train_dataset, - eval_dataset=self.eval_dataset, - data_collator=collate_samples, - tokenizer=processor, - **self.trainer_kwargs, - ) - return TrainerArtifacts(trainer=hf_trainer, training_args=training_args) diff --git a/src/model_configuration/__init__.py b/src/model_configuration/__init__.py deleted file mode 100644 index 0ef48d5..0000000 --- a/src/model_configuration/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .config import ( - DatasetConfig, - EvaluationConfig, - ModelConfig, - ProjectConfig, - TrainingConfig, - VisualizationConfig, -) -from .registry import ConfigRegistry - -# ensure example configs register themselves -from . 
import sam2_bbox # noqa: F401 - -__all__ = [ - "DatasetConfig", - "EvaluationConfig", - "ModelConfig", - "ProjectConfig", - "TrainingConfig", - "VisualizationConfig", - "ConfigRegistry", -] diff --git a/src/model_configuration/config.py b/src/model_configuration/config.py deleted file mode 100644 index 2ef4a44..0000000 --- a/src/model_configuration/config.py +++ /dev/null @@ -1,89 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict, List, Optional - - -def _default_dict() -> Dict[str, Any]: - return {} - - -@dataclass -class DatasetConfig: - name: str - data_root: str - split: str = "test" - split_file: Optional[str] = None - annotation_file: Optional[str] = None - image_folder: Optional[str] = None - mask_folder: Optional[str] = None - extra_params: Dict[str, Any] = field(default_factory=_default_dict) - - def resolve_path(self, relative: Optional[str]) -> Optional[Path]: - if relative is None: - return None - return Path(self.data_root) / relative - - -@dataclass -class ModelConfig: - name_or_path: str - revision: Optional[str] = None - config_name: Optional[str] = None - cache_dir: Optional[str] = None - prompt_type: str = "bbox" - image_size: Optional[int] = None - pipeline_kwargs: Dict[str, Any] = field(default_factory=_default_dict) - adapter_kwargs: Dict[str, Any] = field(default_factory=_default_dict) - - -@dataclass -class TrainingConfig: - output_dir: str = "./outputs" - num_train_epochs: float = 3.0 - per_device_train_batch_size: int = 1 - per_device_eval_batch_size: int = 1 - learning_rate: float = 1e-4 - weight_decay: float = 0.0 - gradient_accumulation_steps: int = 1 - lr_scheduler_type: str = "linear" - warmup_ratio: float = 0.0 - seed: int = 42 - fp16: bool = False - bf16: bool = False - report_to: List[str] = field(default_factory=lambda: ["tensorboard"]) - - -@dataclass -class EvaluationConfig: - output_dir: str = "./results" - metrics: List[str] = field(default_factory=lambda: ["iou", "dice", "precision", "recall"]) - thresholds: List[float] = field(default_factory=lambda: [0.5]) - max_samples: Optional[int] = None - save_predictions: bool = True - - -@dataclass -class VisualizationConfig: - num_samples: int = 20 - overlay_alpha: float = 0.6 - save_dir: str = "./results/visualizations" - - -@dataclass -class ProjectConfig: - dataset: DatasetConfig - model: ModelConfig - training: TrainingConfig = field(default_factory=TrainingConfig) - evaluation: EvaluationConfig = field(default_factory=EvaluationConfig) - visualization: VisualizationConfig = field(default_factory=VisualizationConfig) - - def to_dict(self) -> Dict[str, Any]: - return { - "dataset": self.dataset, - "model": self.model, - "training": self.training, - "evaluation": self.evaluation, - "visualization": self.visualization, - } diff --git a/src/model_configuration/registry.py b/src/model_configuration/registry.py deleted file mode 100644 index fd91327..0000000 --- a/src/model_configuration/registry.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -from typing import Dict - -from .config import ProjectConfig - - -class ConfigRegistry: - """ - Stores reusable project configurations (dataset + model + training bundle). 
- """ - - _registry: Dict[str, ProjectConfig] = {} - - @classmethod - def register(cls, name: str, config: ProjectConfig) -> ProjectConfig: - cls._registry[name] = config - return config - - @classmethod - def get(cls, name: str) -> ProjectConfig: - if name not in cls._registry: - raise KeyError(f"ProjectConfig '{name}' is not registered.") - return cls._registry[name] - - @classmethod - def available(cls) -> Dict[str, ProjectConfig]: - return dict(cls._registry) diff --git a/src/model_configuration/sam2_bbox.py b/src/model_configuration/sam2_bbox.py deleted file mode 100644 index 1ad7101..0000000 --- a/src/model_configuration/sam2_bbox.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -from .config import ( - DatasetConfig, - EvaluationConfig, - ModelConfig, - ProjectConfig, - TrainingConfig, - VisualizationConfig, -) -from .registry import ConfigRegistry - - -SAM2_BBOX_CONFIG = ProjectConfig( - dataset=DatasetConfig( - name="crack500", - data_root="./crack500", - split="test", - split_file="test.txt", - image_folder="testcrop", - mask_folder="testdata", - ), - model=ModelConfig( - name_or_path="facebook/sam2.1-hiera-small", - prompt_type="bbox", - pipeline_kwargs={"batch_size": 1}, - ), - training=TrainingConfig( - output_dir="./outputs/sam2_bbox", - num_train_epochs=5, - per_device_train_batch_size=1, - per_device_eval_batch_size=1, - learning_rate=1e-4, - gradient_accumulation_steps=4, - lr_scheduler_type="cosine", - ), - evaluation=EvaluationConfig( - output_dir="./results/bbox_prompt", - thresholds=[0.3, 0.5, 0.75], - ), - visualization=VisualizationConfig( - save_dir="./results/bbox_prompt/visualizations", - num_samples=20, - ), -) - -ConfigRegistry.register("sam2_bbox_prompt", SAM2_BBOX_CONFIG) diff --git a/src/point_prompt.py b/src/point_prompt.py index ed0514b..4f5da69 100644 --- a/src/point_prompt.py +++ b/src/point_prompt.py @@ -1,5 +1,5 @@ """ -点提示方式的 SAM2 裂缝分割实现(使用 HuggingFace Transformers) +点提示方式的 SAM2 裂缝分割实现 使用骨架采样策略,支持 1, 3, 5 个点 """ @@ -13,102 +13,103 @@ from tqdm import tqdm import json from skimage.morphology import skeletonize -from .hf_sam2_predictor import HFSam2Predictor +from sam2.build_sam import build_sam2 +from sam2.sam2_image_predictor import SAM2ImagePredictor def sample_points_on_skeleton(mask: np.ndarray, num_points: int = 5) -> np.ndarray: """ 在骨架上均匀采样点 - + Args: mask: 二值掩码 (H, W),值为 0 或 255 num_points: 采样点数量 - + Returns: points: 采样点坐标 (N, 2),格式为 [x, y] """ # 确保掩码是二值的 binary_mask = (mask > 0).astype(bool) - + # 骨架化 try: skeleton = skeletonize(binary_mask) except: # 如果骨架化失败,直接使用掩码 skeleton = binary_mask - + # 获取骨架点坐标 (y, x) skeleton_coords = np.argwhere(skeleton) - + if len(skeleton_coords) == 0: # 如果没有骨架点,返回空数组 return np.array([]).reshape(0, 2) - + if len(skeleton_coords) <= num_points: # 如果骨架点数少于需要的点数,返回所有点 # 转换为 (x, y) 格式 return skeleton_coords[:, [1, 0]] - + # 均匀间隔采样 indices = np.linspace(0, len(skeleton_coords) - 1, num_points, dtype=int) sampled_coords = skeleton_coords[indices] - + # 转换为 (x, y) 格式 points = sampled_coords[:, [1, 0]] - + return points def sample_points_per_component( - mask: np.ndarray, + mask: np.ndarray, num_points_per_component: int = 3 ) -> Tuple[np.ndarray, np.ndarray]: """ 为每个连通域独立采样点 - + Args: mask: 二值掩码 (H, W) num_points_per_component: 每个连通域的点数 - + Returns: points: 所有采样点 (N, 2) labels: 点标签,全为 1(正样本) """ # 连通域分析 num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8)) - + all_points = [] - + # 跳过背景 (label 0) for region_id in range(1, num_labels): region_mask = (labels_map == 
diff --git a/src/point_prompt.py b/src/point_prompt.py
index ed0514b..4f5da69 100644
--- a/src/point_prompt.py
+++ b/src/point_prompt.py
@@ -1,5 +1,5 @@
 """
-SAM2 crack segmentation with point prompts (using HuggingFace Transformers)
+SAM2 crack segmentation with point prompts
 Uses a skeleton-sampling strategy; supports 1, 3, or 5 points
 """
@@ -13,102 +13,103 @@ from tqdm import tqdm
 import json
 from skimage.morphology import skeletonize
 
-from .hf_sam2_predictor import HFSam2Predictor
+from sam2.build_sam import build_sam2
+from sam2.sam2_image_predictor import SAM2ImagePredictor
 
 
 def sample_points_on_skeleton(mask: np.ndarray, num_points: int = 5) -> np.ndarray:
     """
     Sample points uniformly along the skeleton
-    
+
     Args:
         mask: binary mask (H, W) with values 0 or 255
         num_points: number of points to sample
-    
+
     Returns:
         points: sampled coordinates (N, 2) in [x, y] format
     """
     # Ensure the mask is binary
     binary_mask = (mask > 0).astype(bool)
-    
+
     # Skeletonize
     try:
         skeleton = skeletonize(binary_mask)
     except:
         # Fall back to the raw mask if skeletonization fails
         skeleton = binary_mask
-    
+
     # Get skeleton point coordinates (y, x)
     skeleton_coords = np.argwhere(skeleton)
-    
+
     if len(skeleton_coords) == 0:
         # No skeleton points: return an empty array
         return np.array([]).reshape(0, 2)
-    
+
     if len(skeleton_coords) <= num_points:
         # Fewer skeleton points than requested: return them all
         # Convert to (x, y) format
         return skeleton_coords[:, [1, 0]]
-    
+
     # Sample at evenly spaced indices
     indices = np.linspace(0, len(skeleton_coords) - 1, num_points, dtype=int)
     sampled_coords = skeleton_coords[indices]
-    
+
     # Convert to (x, y) format
     points = sampled_coords[:, [1, 0]]
-    
+
     return points
 
 
 def sample_points_per_component(
-    mask: np.ndarray, 
+    mask: np.ndarray,
     num_points_per_component: int = 3
 ) -> Tuple[np.ndarray, np.ndarray]:
     """
     Sample points independently for each connected component
-    
+
     Args:
         mask: binary mask (H, W)
         num_points_per_component: number of points per component
-    
+
     Returns:
         points: all sampled points (N, 2)
         labels: point labels, all 1 (positive samples)
     """
     # Connected-component analysis
     num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
-    
+
     all_points = []
-    
+
     # Skip the background (label 0)
     for region_id in range(1, num_labels):
         region_mask = (labels_map == region_id).astype(np.uint8) * 255
-    
+
         # Sample within each component
         points = sample_points_on_skeleton(region_mask, num_points_per_component)
-    
+
         if len(points) > 0:
             all_points.append(points)
-    
+
     if len(all_points) == 0:
         return np.array([]).reshape(0, 2), np.array([])
-    
+
     # Merge all points
     all_points = np.vstack(all_points)
-    
+
     # All points are positive samples
     point_labels = np.ones(len(all_points), dtype=np.int32)
-    
+
     return all_points, point_labels
 
 
 def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
     """
     Load an image and its mask
-    
+
     Args:
         image_path: path to the image
         mask_path: path to the mask
-    
+
     Returns:
         image: RGB image (H, W, 3)
         mask: binary mask (H, W)
@@ -118,79 +119,79 @@ def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
     if image is None:
         raise ValueError(f"Failed to load image: {image_path}")
     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    
+
     # Load the mask
     mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
     if mask is None:
         raise ValueError(f"Failed to load mask: {mask_path}")
-    
+
     return image, mask
 
 
 def predict_with_point_prompt(
-    predictor: HFSam2Predictor,
+    predictor: SAM2ImagePredictor,
     image: np.ndarray,
     points: np.ndarray,
     point_labels: np.ndarray = None
 ) -> np.ndarray:
     """
     Run SAM2 prediction with point prompts
-    
+
     Args:
-        predictor: an HFSam2Predictor instance
+        predictor: a SAM2ImagePredictor instance
         image: RGB image (H, W, 3)
         points: point coordinates (N, 2) in [x, y] format
         point_labels: point labels (N,); 1 = positive, 0 = negative
-    
+
     Returns:
         mask_pred: predicted mask (H, W)
     """
     # Set the image
     predictor.set_image(image)
-    
+
    # No points: return an empty mask
     if len(points) == 0:
         return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
-    
+
     # By default all points are positive
     if point_labels is None:
         point_labels = np.ones(len(points), dtype=np.int32)
-    
+
     # Predict with point prompts
     masks, scores, logits = predictor.predict(
         point_coords=points,
         point_labels=point_labels,
         multimask_output=False,
     )
-    
+
     # Take the first mask (since multimask_output=False)
     mask_pred = masks[0]  # shape: (H, W)
-    
+
     # Convert to 0-255
     mask_pred = (mask_pred * 255).astype(np.uint8)
-    
+
     return mask_pred
 
 
 def process_test_set(
     data_root: str,
     test_file: str,
-    predictor: HFSam2Predictor,
+    predictor: SAM2ImagePredictor,
     output_dir: str,
     num_points: int = 5,
     per_component: bool = False
 ) -> List[Dict]:
     """
     Process the whole test set
-    
+
     Args:
         data_root: dataset root directory
         test_file: path to the test split file (test.txt)
-        predictor: an HFSam2Predictor instance
+        predictor: a SAM2ImagePredictor instance
         output_dir: output directory
         num_points: number of sampled points
         per_component: whether to sample independently per connected component
-    
+
     Returns:
         results: list with per-sample information
     """
@@ -198,27 +199,27 @@ def process_test_set(
     os.makedirs(output_dir, exist_ok=True)
     pred_dir = os.path.join(output_dir, "predictions")
     os.makedirs(pred_dir, exist_ok=True)
-    
+
     # Read the test split file
     with open(test_file, 'r') as f:
         lines = f.readlines()
-    
+
     results = []
-    
+
     print(f"Processing {len(lines)} test images...")
     print(f"Sampling strategy: {'per component' if per_component else 'global'}, {num_points} points")
-    
+
     for line in tqdm(lines, desc="Processing test set"):
         parts = line.strip().split()
         if len(parts) != 2:
             continue
-    
+
         img_rel_path, mask_rel_path = parts
-    
+
         # Build full paths
         img_path = os.path.join(data_root, img_rel_path)
         mask_path = os.path.join(data_root, mask_rel_path)
-    
+
         # Check that the files exist
         if not os.path.exists(img_path):
             print(f"Warning: image not found {img_path}")
@@ -226,11 +227,11 @@ def process_test_set(
         if not os.path.exists(mask_path):
             print(f"Warning: mask not found {mask_path}")
             continue
-    
+
         try:
             # Load the image and mask
             image, mask_gt = load_image_and_mask(img_path, mask_path)
-    
+
             # Sample points from the GT mask
             if per_component:
                 points, point_labels = sample_points_per_component(
@@ -239,18 +240,18 @@ def process_test_set(
             else:
                 points = sample_points_on_skeleton(mask_gt, num_points=num_points)
                 point_labels = np.ones(len(points), dtype=np.int32)
-    
+
             # Predict with SAM2
-            with torch.inference_mode():
+            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
                 mask_pred = predict_with_point_prompt(
                     predictor, image, points, point_labels
                 )
-    
+
             # Save the predicted mask
             img_name = Path(img_rel_path).stem
             pred_path = os.path.join(pred_dir, f"{img_name}_pred.png")
             cv2.imwrite(pred_path, mask_pred)
-    
+
             # Record the result
             results.append({
                 "image_path": img_rel_path,
@@ -259,60 +260,62 @@ def process_test_set(
                 "num_points": len(points),
                 "image_shape": image.shape[:2],
             })
-    
+
         except Exception as e:
             print(f"Failed to process {img_path}: {str(e)}")
             continue
-    
+
     # Save result info
     results_file = os.path.join(output_dir, "results_info.json")
     with open(results_file, 'w') as f:
         json.dump(results, f, indent=2)
-    
+
     print(f"\nDone! Processed {len(results)} images")
     print(f"Predicted masks saved to: {pred_dir}")
     print(f"Result info saved to: {results_file}")
-    
+
     return results
 
 
 def main():
     """Entry point"""
     import argparse
-    
-    parser = argparse.ArgumentParser(description="SAM2 point prompts (HuggingFace) - Crack500 evaluation")
+
+    parser = argparse.ArgumentParser(description="SAM2 point prompts - Crack500 evaluation")
     parser.add_argument("--data_root", type=str, default="./crack500", help="dataset root directory")
     parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="test split file")
-    parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace model ID")
-    parser.add_argument("--output_dir", type=str, default="./results/point_prompt_hf", help="output directory")
+    parser.add_argument("--checkpoint", type=str, default="./sam2/checkpoints/sam2.1_hiera_small.pt", help="model checkpoint")
+    parser.add_argument("--model_cfg", type=str, default="sam2.1_hiera_s.yaml", help="model config")
+    parser.add_argument("--output_dir", type=str, default="./results/point_prompt", help="output directory")
     parser.add_argument("--num_points", type=int, default=5, choices=[1, 3, 5], help="number of sampled points")
     parser.add_argument("--per_component", action="store_true", help="sample independently per connected component")
-    
+
     args = parser.parse_args()
-    
+
     print("=" * 60)
-    print("SAM2 point prompts (HuggingFace) - Crack500 evaluation")
+    print("SAM2 point prompts - Crack500 evaluation")
     print("=" * 60)
     print(f"Dataset root: {args.data_root}")
     print(f"Test file: {args.test_file}")
-    print(f"Model: {args.model_id}")
+    print(f"Model checkpoint: {args.checkpoint}")
+    print(f"Model config: {args.model_cfg}")
     print(f"Number of points: {args.num_points}")
     print(f"Sampling strategy: {'per component' if args.per_component else 'global skeleton'}")
     print(f"Output dir: {args.output_dir}")
     print("=" * 60)
-    
+
     # Check CUDA availability
     if not torch.cuda.is_available():
         print("Warning: CUDA unavailable, falling back to CPU (this will be slow)")
     else:
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
-    
-    # Build the SAM2 predictor
+
+    # Build the SAM2 model
    print("\nLoading SAM2 model...")
-    from .hf_sam2_predictor import build_hf_sam2_predictor
-    predictor = build_hf_sam2_predictor(model_id=args.model_id)
+    sam2_model = build_sam2(args.model_cfg, args.checkpoint)
+    predictor = SAM2ImagePredictor(sam2_model)
     print("Model loaded!")
-    
+
     # Process the test set
     results = process_test_set(
         data_root=args.data_root,
@@ -322,7 +325,7 @@ def main():
         num_points=args.num_points,
         per_component=args.per_component
     )
-    
+
     print("\n" + "=" * 60)
     print("Done! Next, run the evaluation script to compute metrics.")
     print("=" * 60)
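After this change the script drives the native `sam2` package instead of the HF predictor. A condensed sketch of the new flow, assuming the `sam2` package is installed and the default checkpoint/config from `main()` exist (the image/mask paths are placeholders):

```python
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

sam2_model = build_sam2("sam2.1_hiera_s.yaml", "./sam2/checkpoints/sam2.1_hiera_small.pt")
predictor = SAM2ImagePredictor(sam2_model)

image, mask_gt = load_image_and_mask("demo.jpg", "demo_mask.png")  # placeholder paths
points = sample_points_on_skeleton(mask_gt, num_points=5)

# Mirror the autocast context the updated process_test_set now uses.
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    mask_pred = predict_with_point_prompt(predictor, image, points)
```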
diff --git a/src/tasks/__init__.py b/src/tasks/__init__.py
deleted file mode 100644
index 4cf9a61..0000000
--- a/src/tasks/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from .config import TaskConfig, TaskStepConfig
-from .pipeline import TaskRunner
-from .registry import TaskRegistry
-
-# ensure built-in tasks are registered
-from . import examples  # noqa: F401
-
-__all__ = ["TaskConfig", "TaskRunner", "TaskRegistry", "TaskStepConfig"]
diff --git a/src/tasks/config.py b/src/tasks/config.py
deleted file mode 100644
index 68f8618..0000000
--- a/src/tasks/config.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from __future__ import annotations
-
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Literal, Optional
-
-
-TaskStepKind = Literal[
-    "train",
-    "evaluate",
-    "visualize",
-    "bbox_inference",
-    "point_inference",
-    "legacy_evaluation",
-    "legacy_visualization",
-]
-
-
-@dataclass
-class TaskStepConfig:
-    kind: TaskStepKind
-    dataset_split: Optional[str] = None
-    dataset_split_file: Optional[str] = None
-    limit: Optional[int] = None
-    eval_split: Optional[str] = None
-    eval_split_file: Optional[str] = None
-    params: Dict[str, Any] = field(default_factory=dict)
-
-
-@dataclass
-class TaskConfig:
-    name: str
-    description: str
-    project_config_name: str
-    model_key: str = "sam2"
-    steps: List[TaskStepConfig] = field(default_factory=list)
-    dataset_overrides: Dict[str, Any] = field(default_factory=dict)
-    model_overrides: Dict[str, Any] = field(default_factory=dict)
-    training_overrides: Dict[str, Any] = field(default_factory=dict)
-    evaluation_overrides: Dict[str, Any] = field(default_factory=dict)
-    visualization_overrides: Dict[str, Any] = field(default_factory=dict)
diff --git a/src/tasks/examples.py b/src/tasks/examples.py
deleted file mode 100644
index 676f2ea..0000000
--- a/src/tasks/examples.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from __future__ import annotations
-
-from .config import TaskConfig, TaskStepConfig
-from .registry import TaskRegistry
-
-TaskRegistry.register(
-    TaskConfig(
-        name="sam2_crack500_eval",
-        description="Evaluate SAM2 bbox prompt checkpoints on Crack500 and render overlays.",
-        project_config_name="sam2_bbox_prompt",
-        steps=[
-            TaskStepConfig(kind="evaluate", dataset_split="test"),
-            TaskStepConfig(kind="visualize", dataset_split="test", limit=20),
-        ],
-    )
-)
-
-TaskRegistry.register(
-    TaskConfig(
-        name="sam2_crack500_train_eval",
-        description="Fine-tune SAM2 on Crack500 train split, evaluate on val, then visualize results.",
-        project_config_name="sam2_bbox_prompt",
-        steps=[
-            TaskStepConfig(
-                kind="train",
-                dataset_split="train",
-                eval_split="val",
-                params={"num_train_epochs": 2},
-            ),
-            TaskStepConfig(kind="evaluate", dataset_split="val", limit=32),
-            TaskStepConfig(kind="visualize", dataset_split="val", limit=16),
-        ],
-    )
-)
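Before deletion, running one of these bundled tasks took three lines; this sketch mirrors what `src/tasks/run_task.py` (below) did under the hood:

```python
task = TaskRegistry.get("sam2_crack500_eval")
runner = TaskRunner(task)
runner.run()  # evaluate on the test split, then render up to 20 overlays
```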
- """ - data = tomllib.loads(Path(path).read_text(encoding="utf-8")) - task_data = data.get("task", {}) - steps_data: List[Dict[str, Any]] = data.get("steps", []) - steps = [ - TaskStepConfig( - kind=step["kind"], - dataset_split=step.get("dataset_split"), - dataset_split_file=step.get("dataset_split_file"), - limit=step.get("limit"), - eval_split=step.get("eval_split"), - eval_split_file=step.get("eval_split_file"), - params=step.get("params", {}), - ) - for step in steps_data - ] - return TaskConfig( - name=task_data["name"], - description=task_data.get("description", ""), - project_config_name=task_data["project_config_name"], - model_key=task_data.get("model_key", "sam2"), - steps=steps, - dataset_overrides=task_data.get("dataset_overrides", {}), - model_overrides=task_data.get("model_overrides", {}), - training_overrides=task_data.get("training_overrides", {}), - evaluation_overrides=task_data.get("evaluation_overrides", {}), - visualization_overrides=task_data.get("visualization_overrides", {}), - ) diff --git a/src/tasks/pipeline.py b/src/tasks/pipeline.py deleted file mode 100644 index 5f1a1a8..0000000 --- a/src/tasks/pipeline.py +++ /dev/null @@ -1,264 +0,0 @@ -from __future__ import annotations - -import logging -from dataclasses import fields, replace -from pathlib import Path -from typing import Any, Dict, Optional - -from ..bbox_prompt import process_test_set as bbox_process_test_set -from ..dataset import DatasetRegistry -from ..evaluation import PipelineEvaluator -from ..evaluation.utils import extract_mask_from_pipeline_output -from ..hf_sam2_predictor import build_hf_sam2_predictor -from ..legacy_evaluation import evaluate_test_set as legacy_evaluate_test_set -from ..legacy_visualization import ( - create_metrics_distribution_plot, - visualize_test_set as legacy_visualize_test_set, -) -from ..model import FineTuningTrainer, ModelRegistry -from ..model_configuration import ConfigRegistry, DatasetConfig, ProjectConfig -from ..point_prompt import process_test_set as point_process_test_set -from ..visualization import OverlayGenerator -from .config import TaskConfig, TaskStepConfig - -LOGGER = logging.getLogger(__name__) - - -def _replace_dataclass(instance, updates: Dict[str, Any]): - if not updates: - return instance - valid_fields = {f.name for f in fields(type(instance))} - filtered = {k: v for k, v in updates.items() if k in valid_fields} - if not filtered: - return instance - return replace(instance, **filtered) - - -def _override_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig: - updates: Dict[str, Any] = {"split": split} - if split_file: - updates["split_file"] = split_file - return replace(config, **updates) - - -class TaskRunner: - """ - Sequentially executes a series of task steps (train/eval/visualize). 
- """ - - def __init__(self, task_config: TaskConfig, project_config: Optional[ProjectConfig] = None) -> None: - self.task_config = task_config - base_project = project_config or ConfigRegistry.get(task_config.project_config_name) - if project_config is None: - base_project = self._apply_project_overrides(base_project) - self.project_config = base_project - self.adapter = ModelRegistry.create(task_config.model_key, self.project_config.model) - - def run(self) -> None: - LOGGER.info("Starting task '%s'", self.task_config.name) - for idx, step in enumerate(self.task_config.steps, start=1): - LOGGER.info("Running step %d/%d: %s", idx, len(self.task_config.steps), step.kind) - if step.kind == "train": - self._run_train(step) - elif step.kind == "evaluate": - self._run_evaluate(step) - elif step.kind == "visualize": - self._run_visualize(step) - elif step.kind == "bbox_inference": - self._run_bbox_inference(step) - elif step.kind == "point_inference": - self._run_point_inference(step) - elif step.kind == "legacy_evaluation": - self._run_legacy_evaluation(step) - elif step.kind == "legacy_visualization": - self._run_legacy_visualization(step) - else: - raise ValueError(f"Unknown task step: {step.kind}") - - def _build_dataset(self, split: str, split_file: Optional[str]): - dataset_cfg = _override_dataset(self.project_config.dataset, split, split_file) - return DatasetRegistry.create( - dataset_cfg.name, - config=dataset_cfg, - return_hf_dict=True, - ) - - def _apply_project_overrides(self, config: ProjectConfig) -> ProjectConfig: - dataset_cfg = config.dataset - if self.task_config.dataset_overrides: - dataset_cfg = self._apply_dataset_overrides(dataset_cfg, self.task_config.dataset_overrides) - evaluation_cfg = config.evaluation - if self.task_config.evaluation_overrides: - evaluation_cfg = self._apply_simple_overrides(evaluation_cfg, self.task_config.evaluation_overrides) - visualization_cfg = config.visualization - if self.task_config.visualization_overrides: - visualization_cfg = self._apply_simple_overrides( - visualization_cfg, self.task_config.visualization_overrides - ) - model_cfg = config.model - if self.task_config.model_overrides: - model_cfg = self._apply_simple_overrides(model_cfg, self.task_config.model_overrides) - training_cfg = config.training - if self.task_config.training_overrides: - training_cfg = self._apply_simple_overrides(training_cfg, self.task_config.training_overrides) - return replace( - config, - dataset=dataset_cfg, - model=model_cfg, - training=training_cfg, - evaluation=evaluation_cfg, - visualization=visualization_cfg, - ) - - def _apply_dataset_overrides(self, dataset_cfg: DatasetConfig, overrides: Dict[str, Any]) -> DatasetConfig: - overrides = dict(overrides) - extra_updates = overrides.pop("extra_params", {}) - merged_extra = dict(dataset_cfg.extra_params or {}) - merged_extra.update(extra_updates) - return replace(dataset_cfg, **overrides, extra_params=merged_extra) - - def _apply_simple_overrides(self, cfg, overrides: Dict[str, Any]): - overrides = dict(overrides) - return replace(cfg, **overrides) - - def _default_data_root(self) -> str: - return self.project_config.dataset.data_root - - def _default_test_file(self) -> str: - dataset_cfg = self.project_config.dataset - candidate = dataset_cfg.split_file or "test.txt" - candidate_path = Path(candidate) - if candidate_path.is_absolute(): - return str(candidate_path) - return str(Path(dataset_cfg.data_root) / candidate) - - def _default_output_dir(self) -> str: - return 
self.project_config.evaluation.output_dir - - def _run_train(self, step: TaskStepConfig) -> None: - train_dataset = self._build_dataset(step.dataset_split, step.dataset_split_file) - eval_dataset = None - if step.eval_split: - eval_dataset = self._build_dataset(step.eval_split, step.eval_split_file) - trainer_builder = FineTuningTrainer( - adapter=self.adapter, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - training_config=_replace_dataclass( - self.project_config.training, - dict(step.params), - ), - ) - artifacts = trainer_builder.build() - train_result = artifacts.trainer.train() - LOGGER.info("Training result: %s", train_result) - artifacts.trainer.save_model(self.project_config.training.output_dir) - if eval_dataset: - metrics = artifacts.trainer.evaluate() - LOGGER.info("Evaluation metrics: %s", metrics) - - def _run_evaluate(self, step: TaskStepConfig) -> None: - dataset = self._build_dataset(step.dataset_split, step.dataset_split_file) - evaluation_cfg = _replace_dataclass( - self.project_config.evaluation, - {**dict(step.params), "max_samples": step.limit}, - ) - evaluator = PipelineEvaluator( - dataset=dataset, - adapter=self.adapter, - config=evaluation_cfg, - ) - summary = evaluator.run() - LOGGER.info("Evaluation summary: %s", summary) - - def _run_visualize(self, step: TaskStepConfig) -> None: - dataset = self._build_dataset(step.dataset_split, step.dataset_split_file) - vis_config = _replace_dataclass( - self.project_config.visualization, - {**dict(step.params), "num_samples": step.limit or self.project_config.visualization.num_samples}, - ) - overlay = OverlayGenerator(vis_config) - pipe = self.adapter.build_pipeline() - limit = min(vis_config.num_samples, len(dataset)) - for idx in range(limit): - sample = dataset[idx] - preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts")) - pred_mask = extract_mask_from_pipeline_output(preds) - mask = sample.get("labels", {}).get("mask") - overlay.visualize_sample( - image=sample["pixel_values"], - prediction=pred_mask, - mask=mask, - metadata=sample.get("metadata"), - ) - LOGGER.info("Saved overlays to %s", vis_config.save_dir) - - def _run_bbox_inference(self, step: TaskStepConfig) -> None: - params = dict(step.params) - data_root = params.get("data_root", self._default_data_root()) - test_file = params.get("test_file", self._default_test_file()) - expand_ratio = params.get("expand_ratio", params.get("bbox_expand_ratio", 0.05)) - output_dir = params.get("output_dir", self._default_output_dir()) - model_id = params.get("model_id", self.project_config.model.name_or_path) - predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device")) - bbox_process_test_set( - data_root=data_root, - test_file=test_file, - predictor=predictor, - output_dir=output_dir, - expand_ratio=expand_ratio, - ) - - def _run_point_inference(self, step: TaskStepConfig) -> None: - params = dict(step.params) - data_root = params.get("data_root", self._default_data_root()) - test_file = params.get("test_file", self._default_test_file()) - num_points = params.get("num_points", 5) - per_component = params.get("per_component", False) - output_dir = params.get("output_dir") or f"./results/point_prompt_{num_points}pts_hf" - model_id = params.get("model_id", self.project_config.model.name_or_path) - predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device")) - point_process_test_set( - data_root=data_root, - test_file=test_file, - predictor=predictor, - output_dir=output_dir, - 
num_points=num_points, - per_component=per_component, - ) - - def _run_legacy_evaluation(self, step: TaskStepConfig) -> None: - params = dict(step.params) - data_root = params.get("data_root", self._default_data_root()) - test_file = params.get("test_file", self._default_test_file()) - output_dir = params.get("output_dir", self._default_output_dir()) - pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions")) - compute_skeleton = params.get("compute_skeleton", True) - legacy_evaluate_test_set( - data_root=data_root, - test_file=test_file, - pred_dir=pred_dir, - output_dir=output_dir, - compute_skeleton=compute_skeleton, - ) - - def _run_legacy_visualization(self, step: TaskStepConfig) -> None: - params = dict(step.params) - data_root = params.get("data_root", self._default_data_root()) - test_file = params.get("test_file", self._default_test_file()) - output_dir = params.get("output_dir", self._default_output_dir()) - pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions")) - num_samples = params.get("num_samples", 20) - save_all = params.get("save_all", False) - results_csv = params.get("results_csv", str(Path(output_dir) / "evaluation_results.csv")) - legacy_visualize_test_set( - data_root=data_root, - test_file=test_file, - pred_dir=pred_dir, - output_dir=output_dir, - results_csv=results_csv if Path(results_csv).exists() else None, - num_samples=num_samples, - save_all=save_all, - ) - if params.get("create_metrics_plot", True): - create_metrics_distribution_plot(results_csv, output_dir) diff --git a/src/tasks/registry.py b/src/tasks/registry.py deleted file mode 100644 index bd8a498..0000000 --- a/src/tasks/registry.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -from typing import Dict - -from .config import TaskConfig - - -class TaskRegistry: - """ - Holds named task configs for reuse. - """ - - _registry: Dict[str, TaskConfig] = {} - - @classmethod - def register(cls, task: TaskConfig) -> TaskConfig: - cls._registry[task.name] = task - return task - - @classmethod - def get(cls, name: str) -> TaskConfig: - if name not in cls._registry: - raise KeyError(f"Task '{name}' is not registered.") - return cls._registry[name] - - @classmethod - def available(cls) -> Dict[str, TaskConfig]: - return dict(cls._registry) diff --git a/src/tasks/run_task.py b/src/tasks/run_task.py deleted file mode 100644 index e59a423..0000000 --- a/src/tasks/run_task.py +++ /dev/null @@ -1,44 +0,0 @@ -from __future__ import annotations - -import logging -from dataclasses import dataclass -from typing import Optional - -from transformers import HfArgumentParser - -from .config import TaskConfig -from .io import load_task_from_toml -from .pipeline import TaskRunner -from .registry import TaskRegistry - -# ensure built-in tasks are registered when CLI runs -from . 
import examples  # noqa: F401
-LOGGER = logging.getLogger(__name__)
-
-
-@dataclass
-class TaskCLIArguments:
-    task_name: Optional[str] = None
-    task_file: Optional[str] = None
-
-
-def resolve_task(cli_args: TaskCLIArguments) -> TaskConfig:
-    if not cli_args.task_name and not cli_args.task_file:
-        raise ValueError("Provide either --task_name or --task_file.")
-    if cli_args.task_file:
-        return load_task_from_toml(cli_args.task_file)
-    return TaskRegistry.get(cli_args.task_name)
-
-
-def main() -> None:
-    parser = HfArgumentParser(TaskCLIArguments)
-    (cli_args,) = parser.parse_args_into_dataclasses()
-    task = resolve_task(cli_args)
-    runner = TaskRunner(task)
-    runner.run()
-
-
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.INFO)
-    main()
diff --git a/src/legacy_visualization.py b/src/visualization.py
similarity index 96%
rename from src/legacy_visualization.py
rename to src/visualization.py
index 71af849..4c6aedf 100644
--- a/src/legacy_visualization.py
+++ b/src/visualization.py
@@ -81,17 +81,17 @@ def create_comparison_figure(
     # Original image
     axes[0, 0].imshow(image)
-    axes[0, 0].set_title("Original Image", fontsize=16)
+    axes[0, 0].set_title("Original Image", fontsize=12)
     axes[0, 0].axis('off')
 
     # GT mask
     axes[0, 1].imshow(mask_gt, cmap='gray')
-    axes[0, 1].set_title("Ground Truth", fontsize=16)
+    axes[0, 1].set_title("Ground Truth", fontsize=12)
     axes[0, 1].axis('off')
 
     # Predicted mask
     axes[1, 0].imshow(mask_pred, cmap='gray')
-    axes[1, 0].set_title("Prediction", fontsize=16)
+    axes[1, 0].set_title("Prediction", fontsize=12)
     axes[1, 0].axis('off')
 
     # Overlay visualization
@@ -110,16 +110,16 @@ axes[1, 1].text(
         0.02, 0.98, legend_text,
         transform=axes[1, 1].transAxes,
-        fontsize=16,
+        fontsize=10,
         verticalalignment='top',
         bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
     )
-    axes[1, 1].set_title("Overlay Visualization", fontsize=16)
+    axes[1, 1].set_title("Overlay Visualization", fontsize=12)
     axes[1, 1].axis('off')
 
-    # # Set the overall title
-    # if title:
-    #     fig.suptitle(title, fontsize=16, fontweight='bold')
+    # Set the overall title
+    if title:
+        fig.suptitle(title, fontsize=14, fontweight='bold')
 
     plt.tight_layout()
diff --git a/src/visualization/__init__.py b/src/visualization/__init__.py
deleted file mode 100644
index a4046b6..0000000
--- a/src/visualization/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .gallery import build_gallery
-from .overlay import OverlayGenerator
-
-__all__ = ["OverlayGenerator", "build_gallery"]
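For reference, the deleted package's two exports composed like this (a sketch: `image`, `mask_pred`, and `mask_gt` stand in for real arrays):

```python
from pathlib import Path

vis_cfg = VisualizationConfig(save_dir="./results/demo_overlays")
overlay = OverlayGenerator(vis_cfg)
out_path = overlay.visualize_sample(
    image=image, prediction=mask_pred, mask=mask_gt,
    metadata={"image_name": "demo"},
)
build_gallery([out_path], Path("./results/demo_overlays/gallery.png"), columns=1)
```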
- """ - image_paths = list(image_paths) - if not image_paths: - raise ValueError("No images provided for gallery.") - output_path.parent.mkdir(parents=True, exist_ok=True) - images = [Image.open(path).convert("RGB") for path in image_paths] - widths, heights = zip(*(img.size for img in images)) - cell_w = max(widths) - cell_h = max(heights) - rows = (len(images) + columns - 1) // columns - canvas = Image.new("RGB", (cell_w * columns, cell_h * rows), color=(0, 0, 0)) - for idx, img in enumerate(images): - row = idx // columns - col = idx % columns - canvas.paste(img, (col * cell_w, row * cell_h)) - canvas.save(output_path) - return output_path diff --git a/src/visualization/overlay.py b/src/visualization/overlay.py deleted file mode 100644 index 18c6a08..0000000 --- a/src/visualization/overlay.py +++ /dev/null @@ -1,62 +0,0 @@ -from __future__ import annotations - -from pathlib import Path -from typing import Any, Dict, Optional - -import numpy as np -from PIL import Image - -from ..model_configuration import VisualizationConfig - - -class OverlayGenerator: - """ - Turns model predictions into side-by-side overlays for quick inspection. - """ - - def __init__(self, config: VisualizationConfig) -> None: - self.config = config - Path(self.config.save_dir).mkdir(parents=True, exist_ok=True) - - def visualize_sample( - self, - image: np.ndarray, - prediction: np.ndarray, - mask: Optional[np.ndarray], - metadata: Optional[Dict[str, Any]] = None, - ) -> Path: - overlay = self._compose_overlay(image, prediction, mask) - filename = ( - metadata.get("image_name", "sample") - if metadata - else "sample" - ) - target = Path(self.config.save_dir) / f"{filename}_overlay.png" - Image.fromarray(overlay).save(target) - return target - - def _compose_overlay( - self, - image: np.ndarray, - prediction: np.ndarray, - mask: Optional[np.ndarray], - ) -> np.ndarray: - vis = image.copy() - pred_mask = self._normalize(prediction) - color = np.zeros_like(vis) - color[..., 0] = pred_mask - vis = (0.5 * vis + 0.5 * color).astype(np.uint8) - if mask is not None: - gt = self._normalize(mask) - color = np.zeros_like(vis) - color[..., 1] = gt - vis = (0.5 * vis + 0.5 * color).astype(np.uint8) - return vis - - def _normalize(self, array: np.ndarray) -> np.ndarray: - normalized = array.astype(np.float32) - normalized -= normalized.min() - denom = normalized.max() or 1.0 - normalized = normalized / denom - normalized = (normalized * 255).astype(np.uint8) - return normalized diff --git a/src/visualization/run_pipeline_vis.py b/src/visualization/run_pipeline_vis.py deleted file mode 100644 index f9427fa..0000000 --- a/src/visualization/run_pipeline_vis.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -import logging -from dataclasses import dataclass, replace -from typing import Optional - -from transformers import HfArgumentParser - -from ..dataset import DatasetRegistry -from ..evaluation.utils import extract_mask_from_pipeline_output -from ..model import ModelRegistry -from ..model_configuration import ConfigRegistry -from .overlay import OverlayGenerator - -LOGGER = logging.getLogger(__name__) - - -@dataclass -class VisualizationCLIArguments: - config_name: str = "sam2_bbox_prompt" - model_key: str = "sam2" - split: str = "test" - split_file: Optional[str] = None - num_samples: int = 20 - device: Optional[str] = None - - -def main() -> None: - parser = HfArgumentParser(VisualizationCLIArguments) - (cli_args,) = parser.parse_args_into_dataclasses() - project_config = 
diff --git a/src/visualization/run_pipeline_vis.py b/src/visualization/run_pipeline_vis.py
deleted file mode 100644
index f9427fa..0000000
--- a/src/visualization/run_pipeline_vis.py
+++ /dev/null
@@ -1,58 +0,0 @@
-from __future__ import annotations
-
-import logging
-from dataclasses import dataclass, replace
-from typing import Optional
-
-from transformers import HfArgumentParser
-
-from ..dataset import DatasetRegistry
-from ..evaluation.utils import extract_mask_from_pipeline_output
-from ..model import ModelRegistry
-from ..model_configuration import ConfigRegistry
-from .overlay import OverlayGenerator
-
-LOGGER = logging.getLogger(__name__)
-
-
-@dataclass
-class VisualizationCLIArguments:
-    config_name: str = "sam2_bbox_prompt"
-    model_key: str = "sam2"
-    split: str = "test"
-    split_file: Optional[str] = None
-    num_samples: int = 20
-    device: Optional[str] = None
-
-
-def main() -> None:
-    parser = HfArgumentParser(VisualizationCLIArguments)
-    (cli_args,) = parser.parse_args_into_dataclasses()
-    project_config = ConfigRegistry.get(cli_args.config_name)
-    dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
-    dataset = DatasetRegistry.create(
-        dataset_cfg.name,
-        config=dataset_cfg,
-        return_hf_dict=True,
-    )
-    adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
-    overlay = OverlayGenerator(project_config.visualization)
-    pipe = adapter.build_pipeline(device=cli_args.device)
-    limit = min(cli_args.num_samples, len(dataset))
-    for idx in range(limit):
-        sample = dataset[idx]
-        preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
-        pred_mask = extract_mask_from_pipeline_output(preds)
-        mask = sample.get("labels", {}).get("mask")
-        overlay.visualize_sample(
-            image=sample["pixel_values"],
-            prediction=pred_mask,
-            mask=mask,
-            metadata=sample.get("metadata"),
-        )
-    LOGGER.info("Saved overlays to %s", project_config.visualization.save_dir)
-
-
-if __name__ == "__main__":
-    logging.basicConfig(level=logging.INFO)
-    main()
diff --git a/tasks/bbox_eval.toml b/tasks/bbox_eval.toml
deleted file mode 100644
index 9b52291..0000000
--- a/tasks/bbox_eval.toml
+++ /dev/null
@@ -1,34 +0,0 @@
-[task]
-name = "bbox_cli_template"
-description = "Run legacy bbox-prompt inference + evaluation + visualization"
-project_config_name = "sam2_bbox_prompt"
-
-[[steps]]
-kind = "bbox_inference"
-[steps.params]
-data_root = "./crack500"
-test_file = "./crack500/test.txt"
-model_id = "facebook/sam2-hiera-small"
-output_dir = "./results/bbox_prompt"
-expand_ratio = 0.05
-
-[[steps]]
-kind = "legacy_evaluation"
-[steps.params]
-data_root = "./crack500"
-test_file = "./crack500/test.txt"
-output_dir = "./results/bbox_prompt"
-pred_dir = "./results/bbox_prompt/predictions"
-compute_skeleton = true
-
-[[steps]]
-kind = "legacy_visualization"
-[steps.params]
-data_root = "./crack500"
-test_file = "./crack500/test.txt"
-output_dir = "./results/bbox_prompt"
-pred_dir = "./results/bbox_prompt/predictions"
-results_csv = "./results/bbox_prompt/evaluation_results.csv"
-num_samples = 20
-save_all = false
-create_metrics_plot = true
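NOTE: task TOMLs like the one deleted above were consumed by the task CLI removed earlier in this patch. A sketch of the equivalent programmatic call; the import path is an assumption, since only load_task_from_toml, TaskRegistry, and TaskRunner appear in the removed code:

    # Hypothetical module path; the removed CLI defined/imported these names.
    from src.task_runner import TaskRunner, load_task_from_toml

    task = load_task_from_toml("tasks/bbox_eval.toml")  # parses [task] plus the ordered [[steps]]
    TaskRunner(task).run()  # executes bbox_inference -> legacy_evaluation -> legacy_visualization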
"./results/point_prompt_3pts_hf" - -[[steps]] -kind = "legacy_evaluation" -[steps.params] -data_root = "./crack500" -test_file = "./crack500/test.txt" -output_dir = "./results/point_prompt_3pts_hf" -pred_dir = "./results/point_prompt_3pts_hf/predictions" -compute_skeleton = true - -[[steps]] -kind = "legacy_visualization" -[steps.params] -data_root = "./crack500" -test_file = "./crack500/test.txt" -output_dir = "./results/point_prompt_3pts_hf" -pred_dir = "./results/point_prompt_3pts_hf/predictions" -results_csv = "./results/point_prompt_3pts_hf/evaluation_results.csv" -num_samples = 10 -save_all = false -create_metrics_plot = true - -# 5 point config -[[steps]] -kind = "point_inference" -[steps.params] -data_root = "./crack500" -test_file = "./crack500/test.txt" -model_id = "facebook/sam2-hiera-small" -num_points = 5 -per_component = false -output_dir = "./results/point_prompt_5pts_hf" - -[[steps]] -kind = "legacy_evaluation" -[steps.params] -data_root = "./crack500" -test_file = "./crack500/test.txt" -output_dir = "./results/point_prompt_5pts_hf" -pred_dir = "./results/point_prompt_5pts_hf/predictions" -compute_skeleton = true - -[[steps]] -kind = "legacy_visualization" -[steps.params] -data_root = "./crack500" -test_file = "./crack500/test.txt" -output_dir = "./results/point_prompt_5pts_hf" -pred_dir = "./results/point_prompt_5pts_hf/predictions" -results_csv = "./results/point_prompt_5pts_hf/evaluation_results.csv" -num_samples = 10 -save_all = false -create_metrics_plot = true