revert to oldest implemention
This commit is contained in:
parent
4886fc8861
commit
5371973442
@ -1,35 +0,0 @@
|
|||||||
{
|
|
||||||
"crop_size": null,
|
|
||||||
"data_format": "channels_first",
|
|
||||||
"default_to_square": true,
|
|
||||||
"device": null,
|
|
||||||
"disable_grouping": null,
|
|
||||||
"do_center_crop": null,
|
|
||||||
"do_convert_rgb": true,
|
|
||||||
"do_normalize": false,
|
|
||||||
"do_rescale": false,
|
|
||||||
"do_resize": false,
|
|
||||||
"image_mean": [
|
|
||||||
0.485,
|
|
||||||
0.456,
|
|
||||||
0.406
|
|
||||||
],
|
|
||||||
"image_processor_type": "Sam2ImageProcessorFast",
|
|
||||||
"image_std": [
|
|
||||||
0.229,
|
|
||||||
0.224,
|
|
||||||
0.225
|
|
||||||
],
|
|
||||||
"input_data_format": null,
|
|
||||||
"mask_size": {
|
|
||||||
"height": 256,
|
|
||||||
"width": 256
|
|
||||||
},
|
|
||||||
"processor_class": "Sam2VideoProcessor",
|
|
||||||
"resample": 2,
|
|
||||||
"rescale_factor": 0.00392156862745098,
|
|
||||||
"return_tensors": null,
|
|
||||||
"size": {
|
|
||||||
"longest_edge": 1024
|
|
||||||
}
|
|
||||||
}
|
|
||||||
98
pixi.lock
generated
98
pixi.lock
generated
@ -34,6 +34,7 @@ environments:
|
|||||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_ha0e22de_103.conda
|
- conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_ha0e22de_103.conda
|
||||||
- conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-h8577fbf_0.conda
|
- conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025c-h8577fbf_0.conda
|
||||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda
|
- conda: https://conda.anaconda.org/conda-forge/linux-64/zstd-1.5.7-hb78ec9c_6.conda
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl
|
||||||
@ -48,8 +49,10 @@ environments:
|
|||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/51/c7/b64cae5dba3a1b138d7123ec36bb5ccd39d39939f18454407e5468f4763f/fsspec-2025.12.0-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f1/df/8ee1c5dd1e3308b5d5b2f2dfea323bb2f3827da8d654abb6642051199049/ipython-9.8.0-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f1/df/8ee1c5dd1e3308b5d5b2f2dfea323bb2f3827da8d654abb6642051199049/ipython-9.8.0-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl
|
||||||
@ -81,6 +84,7 @@ environments:
|
|||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f6/74/86a07f1d0f42998ca31312f998bd3b9a7eff7f52378f4f270c8679c77fb9/nvidia_nvjitlink_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3b/6c/99acb2f9eb85c29fc6f3a7ac4dccfd992e22666dd08a642b303311326a97/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3b/6c/99acb2f9eb85c29fc6f3a7ac4dccfd992e22666dd08a642b303311326a97/nvidia_nvshmem_cu12-3.3.20-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a2/eb/86626c1bbc2edb86323022371c39aa48df6fd8b0a1647bc274577f72e90b/nvidia_nvtx_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/20/12/38679034af332785aac8774540895e234f4d07f7545804097de4b666afd8/packaging-25.0-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
||||||
@ -88,6 +92,7 @@ environments:
|
|||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4f/87/424511bdcd02c8d7acf9f65caa09f291a519b16bd83c3fb3374b3d4ae951/pillow-12.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4f/87/424511bdcd02c8d7acf9f65caa09f291a519b16bd83c3fb3374b3d4ae951/pillow-12.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4b/a6/38c8e2f318bf67d338f4d629e93b0b4b9af331f455f0390ea8ce4a099b26/portalocker-3.2.0-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/12/ff/e93136587c00a543f4bc768b157fac2c47cd77b180d4f4e5c6efb6ea53a2/psutil-7.2.0-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/12/ff/e93136587c00a543f4bc768b157fac2c47cd77b180d4f4e5c6efb6ea53a2/psutil-7.2.0-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl
|
||||||
@ -122,6 +127,7 @@ environments:
|
|||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/6d/b9/4095b668ea3678bf6a0af005527f39de12fb026516fb3df17495a733b7f8/urllib3-2.6.2-py3-none-any.whl
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl
|
||||||
|
- pypi: /home/dustella/projects/sam2
|
||||||
packages:
|
packages:
|
||||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
|
- conda: https://conda.anaconda.org/conda-forge/linux-64/_libgcc_mutex-0.1-conda_forge.tar.bz2
|
||||||
sha256: fe51de6107f9edc7aa4f786a70f4a883943bc9d39b3bb7307c04c41410990726
|
sha256: fe51de6107f9edc7aa4f786a70f4a883943bc9d39b3bb7307c04c41410990726
|
||||||
@ -144,6 +150,12 @@ packages:
|
|||||||
purls: []
|
purls: []
|
||||||
size: 23621
|
size: 23621
|
||||||
timestamp: 1650670423406
|
timestamp: 1650670423406
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/3e/38/7859ff46355f76f8d19459005ca000b6e7012f2f1ca597746cbcd1fbfe5e/antlr4-python3-runtime-4.9.3.tar.gz
|
||||||
|
name: antlr4-python3-runtime
|
||||||
|
version: 4.9.3
|
||||||
|
sha256: f224469b4168294902bb1efa80a8bf7855f24c99aef99cbefc1bcd3cce77881b
|
||||||
|
requires_dist:
|
||||||
|
- typing ; python_full_version < '3.5'
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/d2/39/e7eaf1799466a4aef85b6a4fe7bd175ad2b1c6345066aa33f1f58d4b18d0/asttokens-3.0.1-py3-none-any.whl
|
||||||
name: asttokens
|
name: asttokens
|
||||||
version: 3.0.1
|
version: 3.0.1
|
||||||
@ -540,6 +552,15 @@ packages:
|
|||||||
- types-tqdm ; extra == 'typing'
|
- types-tqdm ; extra == 'typing'
|
||||||
- types-urllib3 ; extra == 'typing'
|
- types-urllib3 ; extra == 'typing'
|
||||||
requires_python: '>=3.8.0'
|
requires_python: '>=3.8.0'
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-py3-none-any.whl
|
||||||
|
name: hydra-core
|
||||||
|
version: 1.3.2
|
||||||
|
sha256: fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b
|
||||||
|
requires_dist:
|
||||||
|
- omegaconf>=2.2,<2.4
|
||||||
|
- antlr4-python3-runtime==4.9.*
|
||||||
|
- packaging
|
||||||
|
- importlib-resources ; python_full_version < '3.9'
|
||||||
- conda: https://conda.anaconda.org/conda-forge/linux-64/icu-78.1-h33c6efd_0.conda
|
- conda: https://conda.anaconda.org/conda-forge/linux-64/icu-78.1-h33c6efd_0.conda
|
||||||
sha256: 7d6463d0be5092b2ae8f2fad34dc84de83eab8bd44cc0d4be8931881c973c48f
|
sha256: 7d6463d0be5092b2ae8f2fad34dc84de83eab8bd44cc0d4be8931881c973c48f
|
||||||
md5: 518e9bbbc3e3486d6a4519192ba690f8
|
md5: 518e9bbbc3e3486d6a4519192ba690f8
|
||||||
@ -623,6 +644,17 @@ packages:
|
|||||||
- sphinx<6 ; extra == 'full'
|
- sphinx<6 ; extra == 'full'
|
||||||
- tifffile ; extra == 'full'
|
- tifffile ; extra == 'full'
|
||||||
requires_python: '>=3.9'
|
requires_python: '>=3.9'
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/72/73/b3d451dfc523756cf177d3ebb0af76dc7751b341c60e2a21871be400ae29/iopath-0.1.10.tar.gz
|
||||||
|
name: iopath
|
||||||
|
version: 0.1.10
|
||||||
|
sha256: 3311c16a4d9137223e20f141655759933e1eda24f8bff166af834af3c645ef01
|
||||||
|
requires_dist:
|
||||||
|
- tqdm
|
||||||
|
- typing-extensions
|
||||||
|
- portalocker
|
||||||
|
- dataclasses ; python_full_version < '3.7'
|
||||||
|
- boto3 ; extra == 'aws'
|
||||||
|
requires_python: '>=3.6'
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/a3/17/20c2552266728ceba271967b87919664ecc0e33efca29c3efc6baf88c5f9/ipykernel-7.1.0-py3-none-any.whl
|
||||||
name: ipykernel
|
name: ipykernel
|
||||||
version: 7.1.0
|
version: 7.1.0
|
||||||
@ -1174,6 +1206,15 @@ packages:
|
|||||||
version: 12.8.90
|
version: 12.8.90
|
||||||
sha256: 5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f
|
sha256: 5b17e2001cc0d751a5bc2c6ec6d26ad95913324a4adb86788c944f8ce9ba441f
|
||||||
requires_python: '>=3'
|
requires_python: '>=3'
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/e3/94/1843518e420fa3ed6919835845df698c7e27e183cb997394e4a670973a65/omegaconf-2.3.0-py3-none-any.whl
|
||||||
|
name: omegaconf
|
||||||
|
version: 2.3.0
|
||||||
|
sha256: 7b4df175cdb08ba400f45cae3bdcae7ba8365db4d165fc65fd04b050ab63b46b
|
||||||
|
requires_dist:
|
||||||
|
- antlr4-python3-runtime==4.9.*
|
||||||
|
- pyyaml>=5.1.0
|
||||||
|
- dataclasses ; python_full_version == '3.6.*'
|
||||||
|
requires_python: '>=3.6'
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/68/1f/795e7f4aa2eacc59afa4fb61a2e35e510d06414dd5a802b51a012d691b37/opencv_python-4.12.0.88-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
|
||||||
name: opencv-python
|
name: opencv-python
|
||||||
version: 4.12.0.88
|
version: 4.12.0.88
|
||||||
@ -1355,6 +1396,25 @@ packages:
|
|||||||
- pytest>=8.4.2 ; extra == 'test'
|
- pytest>=8.4.2 ; extra == 'test'
|
||||||
- mypy>=1.18.2 ; extra == 'type'
|
- mypy>=1.18.2 ; extra == 'type'
|
||||||
requires_python: '>=3.10'
|
requires_python: '>=3.10'
|
||||||
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/4b/a6/38c8e2f318bf67d338f4d629e93b0b4b9af331f455f0390ea8ce4a099b26/portalocker-3.2.0-py3-none-any.whl
|
||||||
|
name: portalocker
|
||||||
|
version: 3.2.0
|
||||||
|
sha256: 3cdc5f565312224bc570c49337bd21428bba0ef363bbcf58b9ef4a9f11779968
|
||||||
|
requires_dist:
|
||||||
|
- pywin32>=226 ; sys_platform == 'win32'
|
||||||
|
- portalocker[tests] ; extra == 'docs'
|
||||||
|
- coverage-conditional-plugin>=0.9.0 ; extra == 'tests'
|
||||||
|
- portalocker[redis] ; extra == 'tests'
|
||||||
|
- pytest-cov>=2.8.1 ; extra == 'tests'
|
||||||
|
- pytest-mypy>=0.8.0 ; extra == 'tests'
|
||||||
|
- pytest-rerunfailures>=15.0 ; extra == 'tests'
|
||||||
|
- pytest-timeout>=2.1.0 ; extra == 'tests'
|
||||||
|
- pytest>=5.4.1 ; extra == 'tests'
|
||||||
|
- sphinx>=6.0.0 ; extra == 'tests'
|
||||||
|
- types-pywin32>=310.0.0.20250429 ; extra == 'tests'
|
||||||
|
- types-redis ; extra == 'tests'
|
||||||
|
- redis ; extra == 'redis'
|
||||||
|
requires_python: '>=3.9'
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl
|
||||||
name: prompt-toolkit
|
name: prompt-toolkit
|
||||||
version: 3.0.52
|
version: 3.0.52
|
||||||
@ -1544,6 +1604,44 @@ packages:
|
|||||||
- safetensors[testing] ; extra == 'all'
|
- safetensors[testing] ; extra == 'all'
|
||||||
- safetensors[all] ; extra == 'dev'
|
- safetensors[all] ; extra == 'dev'
|
||||||
requires_python: '>=3.9'
|
requires_python: '>=3.9'
|
||||||
|
- pypi: /home/dustella/projects/sam2
|
||||||
|
name: sam-2
|
||||||
|
version: '1.0'
|
||||||
|
sha256: f3cc0abfc266b6f57a457d26bdf52fcaa92567d8fa27ba6ff59ebe60bec0ac69
|
||||||
|
requires_dist:
|
||||||
|
- torch>=2.5.1
|
||||||
|
- torchvision>=0.20.1
|
||||||
|
- numpy>=1.24.4
|
||||||
|
- tqdm>=4.66.1
|
||||||
|
- hydra-core>=1.3.2
|
||||||
|
- iopath>=0.1.10
|
||||||
|
- pillow>=9.4.0
|
||||||
|
- matplotlib>=3.9.1 ; extra == 'notebooks'
|
||||||
|
- jupyter>=1.0.0 ; extra == 'notebooks'
|
||||||
|
- opencv-python>=4.7.0 ; extra == 'notebooks'
|
||||||
|
- eva-decord>=0.6.1 ; extra == 'notebooks'
|
||||||
|
- flask>=3.0.3 ; extra == 'interactive-demo'
|
||||||
|
- flask-cors>=5.0.0 ; extra == 'interactive-demo'
|
||||||
|
- av>=13.0.0 ; extra == 'interactive-demo'
|
||||||
|
- dataclasses-json>=0.6.7 ; extra == 'interactive-demo'
|
||||||
|
- eva-decord>=0.6.1 ; extra == 'interactive-demo'
|
||||||
|
- gunicorn>=23.0.0 ; extra == 'interactive-demo'
|
||||||
|
- imagesize>=1.4.1 ; extra == 'interactive-demo'
|
||||||
|
- pycocotools>=2.0.8 ; extra == 'interactive-demo'
|
||||||
|
- strawberry-graphql>=0.243.0 ; extra == 'interactive-demo'
|
||||||
|
- black==24.2.0 ; extra == 'dev'
|
||||||
|
- usort==1.0.2 ; extra == 'dev'
|
||||||
|
- ufmt==2.0.0b2 ; extra == 'dev'
|
||||||
|
- fvcore>=0.1.5.post20221221 ; extra == 'dev'
|
||||||
|
- pandas>=2.2.2 ; extra == 'dev'
|
||||||
|
- scikit-image>=0.24.0 ; extra == 'dev'
|
||||||
|
- tensorboard>=2.17.0 ; extra == 'dev'
|
||||||
|
- pycocotools>=2.0.8 ; extra == 'dev'
|
||||||
|
- tensordict>=0.6.0 ; extra == 'dev'
|
||||||
|
- opencv-python>=4.7.0 ; extra == 'dev'
|
||||||
|
- submitit>=1.5.1 ; extra == 'dev'
|
||||||
|
requires_python: '>=3.10.0'
|
||||||
|
editable: true
|
||||||
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f4/a2/70401a107d6d7466d64b466927e6b96fcefa99d57494b972608e2f8be50f/scikit_image-0.26.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
- pypi: https://mirror.nju.edu.cn/pypi/web/packages/f4/a2/70401a107d6d7466d64b466927e6b96fcefa99d57494b972608e2f8be50f/scikit_image-0.26.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl
|
||||||
name: scikit-image
|
name: scikit-image
|
||||||
version: 0.26.0
|
version: 0.26.0
|
||||||
|
|||||||
@ -26,3 +26,4 @@ tqdm = ">=4.65.0"
|
|||||||
pandas = ">=2.0.0"
|
pandas = ">=2.0.0"
|
||||||
transformers = ">=4.57.3, <5"
|
transformers = ">=4.57.3, <5"
|
||||||
ipykernel = ">=7.1.0, <8"
|
ipykernel = ">=7.1.0, <8"
|
||||||
|
sam-2 = { path = "/home/dustella/projects/sam2", editable = true }
|
||||||
|
|||||||
@ -1,136 +1,248 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
SAM2 边界框提示方式完整评估流程 (TaskRunner 驱动版本)
|
SAM2 边界框提示方式完整评估流程
|
||||||
|
包括:推理 -> 评估 -> 可视化
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
import argparse
|
import argparse
|
||||||
import logging
|
import time
|
||||||
from dataclasses import dataclass
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from src.tasks.config import TaskConfig, TaskStepConfig
|
# 添加 src 目录到路径
|
||||||
from src.tasks.io import load_task_from_toml
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||||
from src.tasks.pipeline import TaskRunner
|
|
||||||
|
from bbox_prompt import process_test_set, build_sam2, SAM2ImagePredictor
|
||||||
|
from evaluation import evaluate_test_set
|
||||||
|
from visualization import visualize_test_set, create_metrics_distribution_plot
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
def parse_args():
|
||||||
class BBoxCLIArgs:
|
"""解析命令行参数"""
|
||||||
data_root: str
|
|
||||||
test_file: str
|
|
||||||
model_id: str
|
|
||||||
output_dir: str
|
|
||||||
expand_ratio: float
|
|
||||||
num_vis: int
|
|
||||||
vis_all: bool
|
|
||||||
skip_inference: bool
|
|
||||||
skip_evaluation: bool
|
|
||||||
skip_visualization: bool
|
|
||||||
config_name: str
|
|
||||||
task_file: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> BBoxCLIArgs:
|
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="SAM2 边界框提示方式 - TaskRunner 驱动完整评估"
|
description="SAM2 边界框提示方式 - Crack500 数据集完整评估"
|
||||||
)
|
)
|
||||||
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
|
||||||
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
|
# 数据集参数
|
||||||
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
|
|
||||||
parser.add_argument("--output_dir", type=str, default="./results/bbox_prompt", help="输出目录")
|
|
||||||
parser.add_argument("--expand_ratio", type=float, default=0.05, help="边界框扩展比例 (0.0-1.0)")
|
|
||||||
parser.add_argument("--num_vis", type=int, default=20, help="可视化样本数量")
|
|
||||||
parser.add_argument("--vis_all", action="store_true", help="可视化所有样本")
|
|
||||||
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
|
|
||||||
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
|
|
||||||
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--config_name",
|
"--data_root", type=str, default="./crack500",
|
||||||
type=str,
|
help="数据集根目录"
|
||||||
default="sam2_bbox_prompt",
|
|
||||||
help="ProjectConfig 名称(来自 ConfigRegistry)",
|
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--task_file",
|
"--test_file", type=str, default="./crack500/test.txt",
|
||||||
type=str,
|
help="测试集文件路径"
|
||||||
default=None,
|
|
||||||
help="可选:指向 TOML 任务配置(若提供则忽略其余 CLI 参数)",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
return BBoxCLIArgs(
|
|
||||||
data_root=args.data_root,
|
|
||||||
test_file=args.test_file,
|
|
||||||
model_id=args.model_id,
|
|
||||||
output_dir=args.output_dir,
|
|
||||||
expand_ratio=args.expand_ratio,
|
|
||||||
num_vis=args.num_vis,
|
|
||||||
vis_all=args.vis_all,
|
|
||||||
skip_inference=args.skip_inference,
|
|
||||||
skip_evaluation=args.skip_evaluation,
|
|
||||||
skip_visualization=args.skip_visualization,
|
|
||||||
config_name=args.config_name,
|
|
||||||
task_file=args.task_file,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 模型参数
|
||||||
def build_cli_task(args: BBoxCLIArgs) -> TaskConfig:
|
parser.add_argument(
|
||||||
steps: List[TaskStepConfig] = []
|
"--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
|
||||||
common = {
|
help="SAM2 模型检查点路径"
|
||||||
"data_root": args.data_root,
|
)
|
||||||
"test_file": args.test_file,
|
parser.add_argument(
|
||||||
"model_id": args.model_id,
|
"--model_cfg", type=str, default="configs/sam2.1/sam2.1_hiera_s.yaml",
|
||||||
"output_dir": args.output_dir,
|
help="SAM2 模型配置文件"
|
||||||
}
|
|
||||||
if not args.skip_inference:
|
|
||||||
steps.append(
|
|
||||||
TaskStepConfig(
|
|
||||||
kind="bbox_inference",
|
|
||||||
params={**common, "expand_ratio": args.expand_ratio},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not args.skip_evaluation:
|
|
||||||
steps.append(
|
|
||||||
TaskStepConfig(
|
|
||||||
kind="legacy_evaluation",
|
|
||||||
params={
|
|
||||||
**common,
|
|
||||||
"pred_dir": f"{args.output_dir}/predictions",
|
|
||||||
"compute_skeleton": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not args.skip_visualization:
|
|
||||||
steps.append(
|
|
||||||
TaskStepConfig(
|
|
||||||
kind="legacy_visualization",
|
|
||||||
params={
|
|
||||||
**common,
|
|
||||||
"pred_dir": f"{args.output_dir}/predictions",
|
|
||||||
"results_csv": f"{args.output_dir}/evaluation_results.csv",
|
|
||||||
"num_samples": args.num_vis,
|
|
||||||
"save_all": args.vis_all,
|
|
||||||
"create_metrics_plot": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return TaskConfig(
|
|
||||||
name="bbox_cli_run",
|
|
||||||
description="Legacy bbox prompt pipeline executed via TaskRunner",
|
|
||||||
project_config_name=args.config_name,
|
|
||||||
steps=steps,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 输出参数
|
||||||
|
parser.add_argument(
|
||||||
|
"--output_dir", type=str, default="./results/bbox_prompt",
|
||||||
|
help="输出目录"
|
||||||
|
)
|
||||||
|
|
||||||
def main() -> None:
|
# 边界框参数
|
||||||
logging.basicConfig(level=logging.INFO)
|
parser.add_argument(
|
||||||
|
"--expand_ratio", type=float, default=0.05,
|
||||||
|
help="边界框扩展比例 (0.0-1.0)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 可视化参数
|
||||||
|
parser.add_argument(
|
||||||
|
"--num_vis", type=int, default=20,
|
||||||
|
help="可视化样本数量"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--vis_all", action="store_true",
|
||||||
|
help="可视化所有样本"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 流程控制
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_inference", action="store_true",
|
||||||
|
help="跳过推理步骤(使用已有预测结果)"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_evaluation", action="store_true",
|
||||||
|
help="跳过评估步骤"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_visualization", action="store_true",
|
||||||
|
help="跳过可视化步骤"
|
||||||
|
)
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
if args.task_file:
|
|
||||||
task = load_task_from_toml(args.task_file)
|
print("=" * 80)
|
||||||
|
print("SAM2 边界框提示方式 - Crack500 数据集完整评估")
|
||||||
|
print("=" * 80)
|
||||||
|
print(f"数据集根目录: {args.data_root}")
|
||||||
|
print(f"测试集文件: {args.test_file}")
|
||||||
|
print(f"模型检查点: {args.checkpoint}")
|
||||||
|
print(f"模型配置: {args.model_cfg}")
|
||||||
|
print(f"边界框扩展比例: {args.expand_ratio * 100}%")
|
||||||
|
print(f"输出目录: {args.output_dir}")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 检查必要文件
|
||||||
|
if not os.path.exists(args.data_root):
|
||||||
|
print(f"\n错误: 数据集目录不存在 {args.data_root}")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(args.test_file):
|
||||||
|
print(f"\n错误: 测试集文件不存在 {args.test_file}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 创建输出目录
|
||||||
|
os.makedirs(args.output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# ========== 步骤 1: 推理 ==========
|
||||||
|
if not args.skip_inference:
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤 1/3: 使用 SAM2 进行推理")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 检查模型文件
|
||||||
|
if not os.path.exists(args.checkpoint):
|
||||||
|
print(f"\n错误: 模型检查点不存在 {args.checkpoint}")
|
||||||
|
print("请先下载 SAM2 模型权重!")
|
||||||
|
print("运行: cd sam2/checkpoints && ./download_ckpts.sh")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# 检查 CUDA
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||||||
|
else:
|
||||||
|
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
print("\n加载 SAM2 模型...")
|
||||||
|
start_time = time.time()
|
||||||
|
sam2_model = build_sam2(args.model_cfg, args.checkpoint)
|
||||||
|
predictor = SAM2ImagePredictor(sam2_model)
|
||||||
|
print(f"模型加载完成!耗时: {time.time() - start_time:.2f}s")
|
||||||
|
|
||||||
|
# 处理测试集
|
||||||
|
print("\n开始推理...")
|
||||||
|
start_time = time.time()
|
||||||
|
results = process_test_set(
|
||||||
|
data_root=args.data_root,
|
||||||
|
test_file=args.test_file,
|
||||||
|
predictor=predictor,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
expand_ratio=args.expand_ratio
|
||||||
|
)
|
||||||
|
print(f"推理完成!耗时: {time.time() - start_time:.2f}s")
|
||||||
|
print(f"成功处理 {len(results)} 张图像")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n推理过程出错: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
else:
|
else:
|
||||||
task = build_cli_task(args)
|
print("\n跳过推理步骤(使用已有预测结果)")
|
||||||
if not task.steps:
|
|
||||||
raise ValueError("No steps configured for bbox evaluation. Please enable at least one stage.")
|
# ========== 步骤 2: 评估 ==========
|
||||||
runner = TaskRunner(task)
|
if not args.skip_evaluation:
|
||||||
runner.run()
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤 2/3: 评估预测结果")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
pred_dir = os.path.join(args.output_dir, "predictions")
|
||||||
|
|
||||||
|
if not os.path.exists(pred_dir):
|
||||||
|
print(f"\n错误: 预测目录不存在 {pred_dir}")
|
||||||
|
print("请先运行推理步骤!")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
df_results = evaluate_test_set(
|
||||||
|
data_root=args.data_root,
|
||||||
|
test_file=args.test_file,
|
||||||
|
pred_dir=pred_dir,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
compute_skeleton=True
|
||||||
|
)
|
||||||
|
print(f"\n评估完成!耗时: {time.time() - start_time:.2f}s")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n评估过程出错: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print("\n跳过评估步骤")
|
||||||
|
|
||||||
|
# ========== 步骤 3: 可视化 ==========
|
||||||
|
if not args.skip_visualization:
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("步骤 3/3: 生成可视化结果")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
pred_dir = os.path.join(args.output_dir, "predictions")
|
||||||
|
results_csv = os.path.join(args.output_dir, "evaluation_results.csv")
|
||||||
|
|
||||||
|
if not os.path.exists(pred_dir):
|
||||||
|
print(f"\n错误: 预测目录不存在 {pred_dir}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 可视化样本
|
||||||
|
visualize_test_set(
|
||||||
|
data_root=args.data_root,
|
||||||
|
test_file=args.test_file,
|
||||||
|
pred_dir=pred_dir,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
results_csv=results_csv if os.path.exists(results_csv) else None,
|
||||||
|
num_samples=args.num_vis,
|
||||||
|
save_all=args.vis_all
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建指标分布图
|
||||||
|
if os.path.exists(results_csv):
|
||||||
|
create_metrics_distribution_plot(results_csv, args.output_dir)
|
||||||
|
|
||||||
|
print(f"\n可视化完成!耗时: {time.time() - start_time:.2f}s")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n可视化过程出错: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
print("\n跳过可视化步骤")
|
||||||
|
|
||||||
|
# ========== 完成 ==========
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("所有步骤完成!")
|
||||||
|
print("=" * 80)
|
||||||
|
print(f"\n结果保存在: {args.output_dir}")
|
||||||
|
print(f" - 预测掩码: {os.path.join(args.output_dir, 'predictions')}")
|
||||||
|
print(f" - 评估结果: {os.path.join(args.output_dir, 'evaluation_results.csv')}")
|
||||||
|
print(f" - 统计摘要: {os.path.join(args.output_dir, 'evaluation_summary.json')}")
|
||||||
|
print(f" - 可视化图像: {os.path.join(args.output_dir, 'visualizations')}")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -1,222 +1,272 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
"""
|
"""
|
||||||
SAM2 点提示方式完整评估流程 (TaskRunner 驱动版本)
|
SAM2 点提示方式完整评估流程
|
||||||
|
支持 1, 3, 5 个点的对比实验
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
import sys
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional
|
|
||||||
|
|
||||||
import pandas as pd
|
# 添加 src 目录到路径
|
||||||
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
|
||||||
|
|
||||||
from src.tasks.config import TaskConfig, TaskStepConfig
|
from point_prompt import process_test_set, build_sam2, SAM2ImagePredictor
|
||||||
from src.tasks.io import load_task_from_toml
|
from evaluation import evaluate_test_set
|
||||||
from src.tasks.pipeline import TaskRunner
|
from visualization import visualize_test_set, create_metrics_distribution_plot
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
def run_single_experiment(
|
||||||
class PointCLIArgs:
|
data_root: str,
|
||||||
data_root: str
|
test_file: str,
|
||||||
test_file: str
|
checkpoint: str,
|
||||||
model_id: str
|
model_cfg: str,
|
||||||
point_configs: List[int]
|
num_points: int,
|
||||||
per_component: bool
|
per_component: bool = False
|
||||||
num_vis: int
|
):
|
||||||
skip_inference: bool
|
"""运行单个点数的实验"""
|
||||||
skip_evaluation: bool
|
|
||||||
skip_visualization: bool
|
|
||||||
skip_comparison: bool
|
|
||||||
comparison_dir: str
|
|
||||||
config_name: str
|
|
||||||
task_file: Optional[str]
|
|
||||||
|
|
||||||
|
# 设置输出目录
|
||||||
def parse_args() -> PointCLIArgs:
|
|
||||||
parser = argparse.ArgumentParser(description="SAM2 点提示方式 - TaskRunner 驱动多点数对比实验")
|
|
||||||
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
|
||||||
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件路径")
|
|
||||||
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace SAM2 模型 ID")
|
|
||||||
parser.add_argument("--point_configs", type=int, nargs="+", default=[1, 3, 5], help="要测试的点数配置")
|
|
||||||
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样点")
|
|
||||||
parser.add_argument("--num_vis", type=int, default=10, help="可视化样本数量")
|
|
||||||
parser.add_argument("--skip_inference", action="store_true", help="跳过推理步骤")
|
|
||||||
parser.add_argument("--skip_evaluation", action="store_true", help="跳过评估步骤")
|
|
||||||
parser.add_argument("--skip_visualization", action="store_true", help="跳过可视化步骤")
|
|
||||||
parser.add_argument("--skip_comparison", action="store_true", help="跳过实验结果对比")
|
|
||||||
parser.add_argument("--comparison_dir", type=str, default="./results", help="对比结果输出目录")
|
|
||||||
parser.add_argument(
|
|
||||||
"--config_name",
|
|
||||||
type=str,
|
|
||||||
default="sam2_bbox_prompt",
|
|
||||||
help="ProjectConfig 名称(来自 ConfigRegistry)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--task_file",
|
|
||||||
type=str,
|
|
||||||
default=None,
|
|
||||||
help="可选:指向 TOML 任务配置(若提供则跳过 CLI 组装步骤)",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
return PointCLIArgs(
|
|
||||||
data_root=args.data_root,
|
|
||||||
test_file=args.test_file,
|
|
||||||
model_id=args.model_id,
|
|
||||||
point_configs=args.point_configs,
|
|
||||||
per_component=args.per_component,
|
|
||||||
num_vis=args.num_vis,
|
|
||||||
skip_inference=args.skip_inference,
|
|
||||||
skip_evaluation=args.skip_evaluation,
|
|
||||||
skip_visualization=args.skip_visualization,
|
|
||||||
skip_comparison=args.skip_comparison,
|
|
||||||
comparison_dir=args.comparison_dir,
|
|
||||||
config_name=args.config_name,
|
|
||||||
task_file=args.task_file,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def default_output_dir(num_points: int, per_component: bool) -> str:
|
|
||||||
if per_component:
|
if per_component:
|
||||||
return f"./results/point_prompt_{num_points}pts_per_comp_hf"
|
output_dir = f"./results/point_prompt_{num_points}pts_per_comp"
|
||||||
return f"./results/point_prompt_{num_points}pts_hf"
|
else:
|
||||||
|
output_dir = f"./results/point_prompt_{num_points}pts"
|
||||||
|
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print(f"实验配置: {num_points} 个点 ({'每连通域' if per_component else '全局骨架'})")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
def build_task_for_points(args: PointCLIArgs, num_points: int, output_dir: str) -> TaskConfig:
|
# 加载模型
|
||||||
steps: List[TaskStepConfig] = []
|
print("\n加载 SAM2 模型...")
|
||||||
common = {
|
import torch
|
||||||
"data_root": args.data_root,
|
sam2_model = build_sam2(model_cfg, checkpoint)
|
||||||
"test_file": args.test_file,
|
predictor = SAM2ImagePredictor(sam2_model)
|
||||||
"model_id": args.model_id,
|
print("模型加载完成!")
|
||||||
"output_dir": output_dir,
|
|
||||||
}
|
# 推理
|
||||||
if not args.skip_inference:
|
print(f"\n步骤 1/3: 推理 ({num_points} 个点)")
|
||||||
steps.append(
|
start_time = time.time()
|
||||||
TaskStepConfig(
|
results = process_test_set(
|
||||||
kind="point_inference",
|
data_root=data_root,
|
||||||
params={
|
test_file=test_file,
|
||||||
**common,
|
predictor=predictor,
|
||||||
"num_points": num_points,
|
output_dir=output_dir,
|
||||||
"per_component": args.per_component,
|
num_points=num_points,
|
||||||
},
|
per_component=per_component
|
||||||
)
|
|
||||||
)
|
|
||||||
if not args.skip_evaluation:
|
|
||||||
steps.append(
|
|
||||||
TaskStepConfig(
|
|
||||||
kind="legacy_evaluation",
|
|
||||||
params={
|
|
||||||
**common,
|
|
||||||
"pred_dir": f"{output_dir}/predictions",
|
|
||||||
"compute_skeleton": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not args.skip_visualization:
|
|
||||||
steps.append(
|
|
||||||
TaskStepConfig(
|
|
||||||
kind="legacy_visualization",
|
|
||||||
params={
|
|
||||||
**common,
|
|
||||||
"pred_dir": f"{output_dir}/predictions",
|
|
||||||
"results_csv": f"{output_dir}/evaluation_results.csv",
|
|
||||||
"num_samples": args.num_vis,
|
|
||||||
"save_all": False,
|
|
||||||
"create_metrics_plot": True,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return TaskConfig(
|
|
||||||
name=f"point_cli_{num_points}",
|
|
||||||
description=f"Legacy point prompt pipeline ({num_points} pts)",
|
|
||||||
project_config_name=args.config_name,
|
|
||||||
steps=steps,
|
|
||||||
)
|
)
|
||||||
|
print(f"推理完成!耗时: {time.time() - start_time:.2f}s")
|
||||||
|
|
||||||
|
# 评估
|
||||||
|
print(f"\n步骤 2/3: 评估")
|
||||||
|
pred_dir = os.path.join(output_dir, "predictions")
|
||||||
|
start_time = time.time()
|
||||||
|
df_results = evaluate_test_set(
|
||||||
|
data_root=data_root,
|
||||||
|
test_file=test_file,
|
||||||
|
pred_dir=pred_dir,
|
||||||
|
output_dir=output_dir,
|
||||||
|
compute_skeleton=True
|
||||||
|
)
|
||||||
|
print(f"评估完成!耗时: {time.time() - start_time:.2f}s")
|
||||||
|
|
||||||
|
# 可视化
|
||||||
|
print(f"\n步骤 3/3: 可视化")
|
||||||
|
results_csv = os.path.join(output_dir, "evaluation_results.csv")
|
||||||
|
start_time = time.time()
|
||||||
|
visualize_test_set(
|
||||||
|
data_root=data_root,
|
||||||
|
test_file=test_file,
|
||||||
|
pred_dir=pred_dir,
|
||||||
|
output_dir=output_dir,
|
||||||
|
results_csv=results_csv,
|
||||||
|
num_samples=10,
|
||||||
|
save_all=False
|
||||||
|
)
|
||||||
|
create_metrics_distribution_plot(results_csv, output_dir)
|
||||||
|
print(f"可视化完成!耗时: {time.time() - start_time:.2f}s")
|
||||||
|
|
||||||
|
return df_results
|
||||||
|
|
||||||
|
|
||||||
def load_results_csv(output_dir: str) -> Optional[pd.DataFrame]:
|
def compare_results(results_dict: dict, output_dir: str = "./results"):
|
||||||
csv_path = Path(output_dir) / "evaluation_results.csv"
|
"""对比不同点数的结果"""
|
||||||
if not csv_path.exists():
|
import pandas as pd
|
||||||
return None
|
|
||||||
return pd.read_csv(csv_path)
|
|
||||||
|
|
||||||
|
|
||||||
def compare_results(results: Dict[int, pd.DataFrame], output_dir: str) -> None:
|
|
||||||
if not results:
|
|
||||||
return
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
|
||||||
summary_rows = []
|
|
||||||
for num_points, df in results.items():
|
|
||||||
summary_rows.append(
|
|
||||||
{
|
|
||||||
"num_points": num_points,
|
|
||||||
"iou_mean": df["iou"].mean(),
|
|
||||||
"iou_std": df["iou"].std(),
|
|
||||||
"dice_mean": df["dice"].mean(),
|
|
||||||
"dice_std": df["dice"].std(),
|
|
||||||
"f1_mean": df["f1_score"].mean(),
|
|
||||||
"f1_std": df["f1_score"].std(),
|
|
||||||
"precision_mean": df["precision"].mean(),
|
|
||||||
"recall_mean": df["recall"].mean(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
df_summary = pd.DataFrame(summary_rows).sort_values("num_points")
|
|
||||||
summary_path = Path(output_dir) / "point_comparison" / "comparison_summary.csv"
|
|
||||||
summary_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
df_summary.to_csv(summary_path, index=False)
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
metrics_to_plot = [
|
print("\n" + "=" * 80)
|
||||||
("iou_mean", "iou_std", "IoU"),
|
print("对比不同点数的性能")
|
||||||
("dice_mean", "dice_std", "Dice"),
|
print("=" * 80)
|
||||||
("f1_mean", "f1_std", "F1-Score"),
|
|
||||||
]
|
# 收集所有结果
|
||||||
|
summary = []
|
||||||
|
for num_points, df in results_dict.items():
|
||||||
|
metrics = {
|
||||||
|
'num_points': num_points,
|
||||||
|
'iou_mean': df['iou'].mean(),
|
||||||
|
'iou_std': df['iou'].std(),
|
||||||
|
'dice_mean': df['dice'].mean(),
|
||||||
|
'dice_std': df['dice'].std(),
|
||||||
|
'f1_mean': df['f1_score'].mean(),
|
||||||
|
'f1_std': df['f1_score'].std(),
|
||||||
|
'precision_mean': df['precision'].mean(),
|
||||||
|
'recall_mean': df['recall'].mean(),
|
||||||
|
}
|
||||||
|
summary.append(metrics)
|
||||||
|
|
||||||
|
df_summary = pd.DataFrame(summary)
|
||||||
|
|
||||||
|
# 打印对比表格
|
||||||
|
print("\n性能对比:")
|
||||||
|
print(df_summary.to_string(index=False))
|
||||||
|
|
||||||
|
# 保存对比结果
|
||||||
|
comparison_dir = os.path.join(output_dir, "point_comparison")
|
||||||
|
os.makedirs(comparison_dir, exist_ok=True)
|
||||||
|
|
||||||
|
csv_path = os.path.join(comparison_dir, "comparison_summary.csv")
|
||||||
|
df_summary.to_csv(csv_path, index=False)
|
||||||
|
print(f"\n对比结果已保存到: {csv_path}")
|
||||||
|
|
||||||
|
# 绘制对比图
|
||||||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||||||
xs = df_summary["num_points"].tolist()
|
|
||||||
for ax, (mean_col, std_col, title) in zip(axes, metrics_to_plot):
|
metrics_to_plot = [
|
||||||
ax.errorbar(
|
('iou_mean', 'iou_std', 'IoU'),
|
||||||
xs,
|
('dice_mean', 'dice_std', 'Dice'),
|
||||||
df_summary[mean_col],
|
('f1_mean', 'f1_std', 'F1-Score')
|
||||||
yerr=df_summary[std_col],
|
]
|
||||||
marker="o",
|
|
||||||
capsize=5,
|
for idx, (mean_col, std_col, title) in enumerate(metrics_to_plot):
|
||||||
linewidth=2,
|
ax = axes[idx]
|
||||||
markersize=8,
|
x = df_summary['num_points']
|
||||||
)
|
y = df_summary[mean_col]
|
||||||
ax.set_xlabel("Number of Points", fontsize=12)
|
yerr = df_summary[std_col]
|
||||||
|
|
||||||
|
ax.errorbar(x, y, yerr=yerr, marker='o', capsize=5, linewidth=2, markersize=8)
|
||||||
|
ax.set_xlabel('Number of Points', fontsize=12)
|
||||||
ax.set_ylabel(title, fontsize=12)
|
ax.set_ylabel(title, fontsize=12)
|
||||||
ax.set_title(f"{title} vs Number of Points", fontsize=14)
|
ax.set_title(f'{title} vs Number of Points', fontsize=14)
|
||||||
ax.grid(True, alpha=0.3)
|
ax.grid(True, alpha=0.3)
|
||||||
ax.set_xticks(xs)
|
ax.set_xticks(x)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
plot_path = summary_path.with_name("performance_comparison.png")
|
|
||||||
fig.savefig(plot_path, dpi=150, bbox_inches="tight")
|
plot_path = os.path.join(comparison_dir, "performance_comparison.png")
|
||||||
|
fig.savefig(plot_path, dpi=150, bbox_inches='tight')
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
|
print(f"对比图已保存到: {plot_path}")
|
||||||
|
|
||||||
def main() -> None:
|
return df_summary
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
args = parse_args()
|
|
||||||
if args.task_file:
|
def main():
|
||||||
task = load_task_from_toml(args.task_file)
|
"""主函数"""
|
||||||
TaskRunner(task).run()
|
parser = argparse.ArgumentParser(
|
||||||
return
|
description="SAM2 点提示方式 - 多点数对比实验"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 数据集参数
|
||||||
|
parser.add_argument(
|
||||||
|
"--data_root", type=str, default="./crack500",
|
||||||
|
help="数据集根目录"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--test_file", type=str, default="./crack500/test.txt",
|
||||||
|
help="测试集文件路径"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 模型参数
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint", type=str, default="../sam2/checkpoints/sam2.1_hiera_small.pt",
|
||||||
|
help="SAM2 模型检查点路径"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--model_cfg", type=str, default="sam2.1_hiera_s.yaml",
|
||||||
|
help="SAM2 模型配置文件"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 实验参数
|
||||||
|
parser.add_argument(
|
||||||
|
"--point_configs", type=int, nargs='+', default=[1, 3, 5],
|
||||||
|
help="要测试的点数配置"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--per_component", action="store_true",
|
||||||
|
help="为每个连通域独立采样"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--skip_comparison", action="store_true",
|
||||||
|
help="跳过对比分析"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
print("SAM2 点提示方式 - 多点数对比实验")
|
||||||
|
print("=" * 80)
|
||||||
|
print(f"数据集根目录: {args.data_root}")
|
||||||
|
print(f"测试集文件: {args.test_file}")
|
||||||
|
print(f"模型检查点: {args.checkpoint}")
|
||||||
|
print(f"点数配置: {args.point_configs}")
|
||||||
|
print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
# 检查 CUDA
|
||||||
|
import torch
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("警告: CUDA 不可用,将使用 CPU(速度会很慢)")
|
||||||
|
else:
|
||||||
|
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
|
||||||
|
# 运行所有实验
|
||||||
|
results_dict = {}
|
||||||
|
|
||||||
comparison_data: Dict[int, pd.DataFrame] = {}
|
|
||||||
for num_points in args.point_configs:
|
for num_points in args.point_configs:
|
||||||
output_dir = default_output_dir(num_points, args.per_component)
|
try:
|
||||||
task = build_task_for_points(args, num_points, output_dir)
|
df_results = run_single_experiment(
|
||||||
if not task.steps:
|
data_root=args.data_root,
|
||||||
|
test_file=args.test_file,
|
||||||
|
checkpoint=args.checkpoint,
|
||||||
|
model_cfg=args.model_cfg,
|
||||||
|
num_points=num_points,
|
||||||
|
per_component=args.per_component
|
||||||
|
)
|
||||||
|
results_dict[num_points] = df_results
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\n实验失败 ({num_points} 个点): {str(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
continue
|
continue
|
||||||
TaskRunner(task).run()
|
|
||||||
if not args.skip_comparison and not args.skip_evaluation:
|
# 对比分析
|
||||||
df = load_results_csv(output_dir)
|
if not args.skip_comparison and len(results_dict) > 1:
|
||||||
if df is not None:
|
try:
|
||||||
comparison_data[num_points] = df
|
compare_results(results_dict)
|
||||||
if not args.skip_comparison and comparison_data:
|
except Exception as e:
|
||||||
compare_results(comparison_data, args.comparison_dir)
|
print(f"\n对比分析失败: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# 完成
|
||||||
|
print("\n" + "=" * 80)
|
||||||
|
print("所有实验完成!")
|
||||||
|
print("=" * 80)
|
||||||
|
print("\n结果目录:")
|
||||||
|
for num_points in args.point_configs:
|
||||||
|
if args.per_component:
|
||||||
|
output_dir = f"./results/point_prompt_{num_points}pts_per_comp"
|
||||||
|
else:
|
||||||
|
output_dir = f"./results/point_prompt_{num_points}pts"
|
||||||
|
print(f" - {num_points} 个点: {output_dir}")
|
||||||
|
|
||||||
|
if not args.skip_comparison and len(results_dict) > 1:
|
||||||
|
print(f" - 对比分析: ./results/point_comparison")
|
||||||
|
|
||||||
|
print("=" * 80)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -1,26 +1,145 @@
|
|||||||
"""
|
"""
|
||||||
边界框提示方式的 SAM2 裂缝分割实现(使用 HuggingFace Transformers)
|
边界框提示方式的 SAM2 裂缝分割实现
|
||||||
从 GT 掩码中提取边界框,使用 SAM2 进行分割
|
从 GT 掩码中提取边界框,使用 SAM2 进行分割
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import List, Tuple, Dict
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import json
|
import json
|
||||||
import cv2
|
|
||||||
|
|
||||||
from .dataset.utils import extract_bboxes_from_mask, load_image_and_mask
|
from sam2.build_sam import build_sam2
|
||||||
from .hf_sam2_predictor import HFSam2Predictor
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||||
from .model.inference import predict_with_bbox_prompt
|
|
||||||
|
|
||||||
|
def extract_bboxes_from_mask(mask: np.ndarray, expand_ratio: float = 0.0) -> List[np.ndarray]:
|
||||||
|
"""
|
||||||
|
从二值掩码中提取所有连通域的边界框
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mask: 二值掩码 (H, W),值为 0 或 255
|
||||||
|
expand_ratio: 边界框扩展比例,例如 0.05 表示扩展 5%
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of bounding boxes in format [x1, y1, x2, y2]
|
||||||
|
"""
|
||||||
|
# 确保掩码是二值的
|
||||||
|
binary_mask = (mask > 0).astype(np.uint8)
|
||||||
|
|
||||||
|
# 连通域分析
|
||||||
|
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(
|
||||||
|
binary_mask, connectivity=8
|
||||||
|
)
|
||||||
|
|
||||||
|
bboxes = []
|
||||||
|
# 跳过背景 (label 0)
|
||||||
|
for i in range(1, num_labels):
|
||||||
|
x, y, w, h, area = stats[i]
|
||||||
|
|
||||||
|
# 过滤太小的区域(可能是噪声)
|
||||||
|
if area < 10: # 最小面积阈值
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 计算边界框
|
||||||
|
x1, y1 = x, y
|
||||||
|
x2, y2 = x + w, y + h
|
||||||
|
|
||||||
|
# 扩展边界框(如果需要)
|
||||||
|
if expand_ratio > 0:
|
||||||
|
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
|
||||||
|
w_new = w * (1 + expand_ratio)
|
||||||
|
h_new = h * (1 + expand_ratio)
|
||||||
|
x1 = max(0, int(cx - w_new / 2))
|
||||||
|
y1 = max(0, int(cy - h_new / 2))
|
||||||
|
x2 = min(mask.shape[1], int(cx + w_new / 2))
|
||||||
|
y2 = min(mask.shape[0], int(cy + h_new / 2))
|
||||||
|
|
||||||
|
bboxes.append(np.array([x1, y1, x2, y2]))
|
||||||
|
|
||||||
|
return bboxes
|
||||||
|
|
||||||
|
|
||||||
|
def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
加载图像和掩码
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_path: 图像路径
|
||||||
|
mask_path: 掩码路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
image: RGB 图像 (H, W, 3)
|
||||||
|
mask: 二值掩码 (H, W)
|
||||||
|
"""
|
||||||
|
# 加载图像
|
||||||
|
image = cv2.imread(image_path)
|
||||||
|
if image is None:
|
||||||
|
raise ValueError(f"无法加载图像: {image_path}")
|
||||||
|
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
|
# 加载掩码
|
||||||
|
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
||||||
|
if mask is None:
|
||||||
|
raise ValueError(f"无法加载掩码: {mask_path}")
|
||||||
|
|
||||||
|
return image, mask
|
||||||
|
|
||||||
|
|
||||||
|
def predict_with_bbox_prompt(
|
||||||
|
predictor: SAM2ImagePredictor,
|
||||||
|
image: np.ndarray,
|
||||||
|
bboxes: List[np.ndarray]
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
使用边界框提示进行 SAM2 预测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
predictor: SAM2ImagePredictor 实例
|
||||||
|
image: RGB 图像 (H, W, 3)
|
||||||
|
bboxes: 边界框列表,每个为 [x1, y1, x2, y2]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
combined_mask: 合并后的预测掩码 (H, W)
|
||||||
|
"""
|
||||||
|
# 设置图像
|
||||||
|
predictor.set_image(image)
|
||||||
|
|
||||||
|
# 如果没有边界框,返回空掩码
|
||||||
|
if len(bboxes) == 0:
|
||||||
|
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||||
|
|
||||||
|
# 合并所有预测掩码
|
||||||
|
combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
||||||
|
|
||||||
|
# 对每个边界框进行预测
|
||||||
|
for bbox in bboxes:
|
||||||
|
masks, scores, logits = predictor.predict(
|
||||||
|
point_coords=None,
|
||||||
|
point_labels=None,
|
||||||
|
box=bbox[None, :], # shape: (1, 4)
|
||||||
|
multimask_output=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 取第一个掩码(因为 multimask_output=False)
|
||||||
|
mask = masks[0] # shape: (H, W)
|
||||||
|
|
||||||
|
# 合并到总掩码中
|
||||||
|
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
|
||||||
|
|
||||||
|
# 转换为 0-255
|
||||||
|
combined_mask = combined_mask * 255
|
||||||
|
|
||||||
|
return combined_mask
|
||||||
|
|
||||||
|
|
||||||
def process_test_set(
|
def process_test_set(
|
||||||
data_root: str,
|
data_root: str,
|
||||||
test_file: str,
|
test_file: str,
|
||||||
predictor: HFSam2Predictor,
|
predictor: SAM2ImagePredictor,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
expand_ratio: float = 0.0
|
expand_ratio: float = 0.0
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
@ -30,7 +149,7 @@ def process_test_set(
|
|||||||
Args:
|
Args:
|
||||||
data_root: 数据集根目录
|
data_root: 数据集根目录
|
||||||
test_file: 测试集文件路径 (test.txt)
|
test_file: 测试集文件路径 (test.txt)
|
||||||
predictor: HFSam2Predictor 实例
|
predictor: SAM2ImagePredictor 实例
|
||||||
output_dir: 输出目录
|
output_dir: 输出目录
|
||||||
expand_ratio: 边界框扩展比例
|
expand_ratio: 边界框扩展比例
|
||||||
|
|
||||||
@ -77,7 +196,7 @@ def process_test_set(
|
|||||||
bboxes = extract_bboxes_from_mask(mask_gt, expand_ratio=expand_ratio)
|
bboxes = extract_bboxes_from_mask(mask_gt, expand_ratio=expand_ratio)
|
||||||
|
|
||||||
# 使用 SAM2 预测
|
# 使用 SAM2 预测
|
||||||
with torch.inference_mode():
|
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||||
mask_pred = predict_with_bbox_prompt(predictor, image, bboxes)
|
mask_pred = predict_with_bbox_prompt(predictor, image, bboxes)
|
||||||
|
|
||||||
# 保存预测掩码
|
# 保存预测掩码
|
||||||
@ -96,9 +215,6 @@ def process_test_set(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"处理失败 {img_path}: {str(e)}")
|
print(f"处理失败 {img_path}: {str(e)}")
|
||||||
# print stack trace
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 保存结果信息
|
# 保存结果信息
|
||||||
@ -118,20 +234,22 @@ def main():
|
|||||||
# 配置参数
|
# 配置参数
|
||||||
DATA_ROOT = "./crack500"
|
DATA_ROOT = "./crack500"
|
||||||
TEST_FILE = "./crack500/test.txt"
|
TEST_FILE = "./crack500/test.txt"
|
||||||
OUTPUT_DIR = "./results/bbox_prompt_hf"
|
OUTPUT_DIR = "./results/bbox_prompt"
|
||||||
|
|
||||||
# HuggingFace SAM2 模型
|
# SAM2 模型配置
|
||||||
MODEL_ID = "facebook/sam2-hiera-small"
|
CHECKPOINT = "./sam2/checkpoints/sam2.1_hiera_small.pt"
|
||||||
|
MODEL_CFG = "sam2.1_hiera_s.yaml"
|
||||||
|
|
||||||
# 边界框扩展比例
|
# 边界框扩展比例
|
||||||
EXPAND_RATIO = 0.05 # 5% 扩展
|
EXPAND_RATIO = 0.05 # 5% 扩展
|
||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("SAM2 边界框提示方式 (HuggingFace) - Crack500 数据集评估")
|
print("SAM2 边界框提示方式 - Crack500 数据集评估")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"数据集根目录: {DATA_ROOT}")
|
print(f"数据集根目录: {DATA_ROOT}")
|
||||||
print(f"测试集文件: {TEST_FILE}")
|
print(f"测试集文件: {TEST_FILE}")
|
||||||
print(f"模型: {MODEL_ID}")
|
print(f"模型检查点: {CHECKPOINT}")
|
||||||
|
print(f"模型配置: {MODEL_CFG}")
|
||||||
print(f"边界框扩展比例: {EXPAND_RATIO * 100}%")
|
print(f"边界框扩展比例: {EXPAND_RATIO * 100}%")
|
||||||
print(f"输出目录: {OUTPUT_DIR}")
|
print(f"输出目录: {OUTPUT_DIR}")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
@ -142,10 +260,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
|
||||||
# 构建 SAM2 predictor
|
# 构建 SAM2 模型
|
||||||
print("\n加载 SAM2 模型...")
|
print("\n加载 SAM2 模型...")
|
||||||
from .hf_sam2_predictor import build_hf_sam2_predictor
|
sam2_model = build_sam2(MODEL_CFG, CHECKPOINT)
|
||||||
predictor = build_hf_sam2_predictor(model_id=MODEL_ID)
|
predictor = SAM2ImagePredictor(sam2_model)
|
||||||
print("模型加载完成!")
|
print("模型加载完成!")
|
||||||
|
|
||||||
# 处理测试集
|
# 处理测试集
|
||||||
|
|||||||
@ -1,16 +0,0 @@
|
|||||||
from .base import BaseDataset, DatasetRecord, ModelReadySample, collate_samples
|
|
||||||
from .registry import DatasetRegistry
|
|
||||||
from .utils import extract_bboxes_from_mask, load_image_and_mask
|
|
||||||
|
|
||||||
# ensure built-in datasets register themselves
|
|
||||||
from . import crack500 # noqa: F401
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseDataset",
|
|
||||||
"DatasetRecord",
|
|
||||||
"ModelReadySample",
|
|
||||||
"collate_samples",
|
|
||||||
"DatasetRegistry",
|
|
||||||
"extract_bboxes_from_mask",
|
|
||||||
"load_image_and_mask",
|
|
||||||
]
|
|
||||||
@ -1,167 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import abc
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
import torch
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
|
|
||||||
from ..model_configuration.config import DatasetConfig
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DatasetRecord:
|
|
||||||
"""
|
|
||||||
Lightweight description of a single sample on disk.
|
|
||||||
"""
|
|
||||||
|
|
||||||
image_path: Path
|
|
||||||
mask_path: Optional[Path] = None
|
|
||||||
prompt_path: Optional[Path] = None
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelReadySample:
|
|
||||||
"""
|
|
||||||
Standard container that mirrors what Hugging Face pipelines expect.
|
|
||||||
"""
|
|
||||||
|
|
||||||
pixel_values: torch.Tensor | np.ndarray
|
|
||||||
prompts: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
labels: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
def to_hf_dict(self) -> Dict[str, Any]:
|
|
||||||
payload = {
|
|
||||||
"pixel_values": self.pixel_values,
|
|
||||||
"metadata": self.metadata,
|
|
||||||
}
|
|
||||||
if self.prompts:
|
|
||||||
payload["prompts"] = self.prompts
|
|
||||||
if self.labels:
|
|
||||||
payload["labels"] = self.labels
|
|
||||||
return payload
|
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset):
|
|
||||||
"""
|
|
||||||
Common dataset base class that handles record bookkeeping, IO, and
|
|
||||||
formatting tensors for Hugging Face pipelines.
|
|
||||||
"""
|
|
||||||
|
|
||||||
dataset_name: str = "base"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: DatasetConfig,
|
|
||||||
transforms: Optional[Callable[[ModelReadySample], ModelReadySample]] = None,
|
|
||||||
return_hf_dict: bool = True,
|
|
||||||
) -> None:
|
|
||||||
self.config = config
|
|
||||||
self.transforms = transforms
|
|
||||||
self.return_hf_dict = return_hf_dict
|
|
||||||
self.records: List[DatasetRecord] = self.load_records()
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
return len(self.records)
|
|
||||||
|
|
||||||
def __getitem__(self, index: int) -> Dict[str, Any] | ModelReadySample:
|
|
||||||
record = self.records[index]
|
|
||||||
sample = self.prepare_sample(record)
|
|
||||||
if self.transforms:
|
|
||||||
sample = self.transforms(sample)
|
|
||||||
return sample.to_hf_dict() if self.return_hf_dict else sample
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def load_records(self) -> List[DatasetRecord]:
|
|
||||||
"""
|
|
||||||
Scan the dataset directory / annotation files and return
|
|
||||||
structured references to each item on disk.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def prepare_sample(self, record: DatasetRecord) -> ModelReadySample:
|
|
||||||
"""
|
|
||||||
Load image/mask/prompt data from disk and wrap it inside ModelReadySample.
|
|
||||||
Subclasses can override this to implement custom augmentations or prompt generation.
|
|
||||||
"""
|
|
||||||
image = self._load_image(record.image_path)
|
|
||||||
mask = (
|
|
||||||
self._load_mask(record.mask_path)
|
|
||||||
if record.mask_path is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
prompts = self.build_prompts(record, mask)
|
|
||||||
labels = {"mask": mask} if mask is not None else {}
|
|
||||||
sample = ModelReadySample(
|
|
||||||
pixel_values=image,
|
|
||||||
prompts=prompts,
|
|
||||||
labels=labels,
|
|
||||||
metadata=record.metadata,
|
|
||||||
)
|
|
||||||
return sample
|
|
||||||
|
|
||||||
def build_prompts(
|
|
||||||
self, record: DatasetRecord, mask: Optional[np.ndarray]
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Derive prompts from metadata or masks.
|
|
||||||
Default implementation extracts bounding boxes from masks.
|
|
||||||
"""
|
|
||||||
if mask is None:
|
|
||||||
return {}
|
|
||||||
boxes = self._mask_to_bboxes(mask)
|
|
||||||
return {"boxes": boxes}
|
|
||||||
|
|
||||||
def _load_image(self, path: Path) -> np.ndarray:
|
|
||||||
image = Image.open(path).convert("RGB")
|
|
||||||
return np.array(image)
|
|
||||||
|
|
||||||
def _load_mask(self, path: Optional[Path]) -> Optional[np.ndarray]:
|
|
||||||
if path is None:
|
|
||||||
return None
|
|
||||||
mask = Image.open(path).convert("L")
|
|
||||||
return np.array(mask)
|
|
||||||
|
|
||||||
def _mask_to_bboxes(self, mask: np.ndarray) -> List[List[int]]:
|
|
||||||
"""
|
|
||||||
Helper that mirrors the legacy bbox extraction pipeline.
|
|
||||||
"""
|
|
||||||
if mask.ndim != 2:
|
|
||||||
raise ValueError("Mask must be 2-dimensional.")
|
|
||||||
ys, xs = np.where(mask > 0)
|
|
||||||
if ys.size == 0:
|
|
||||||
return []
|
|
||||||
x_min, x_max = xs.min(), xs.max()
|
|
||||||
y_min, y_max = ys.min(), ys.max()
|
|
||||||
return [[int(x_min), int(y_min), int(x_max), int(y_max)]]
|
|
||||||
|
|
||||||
|
|
||||||
def collate_samples(batch: Iterable[Dict[str, Any] | ModelReadySample]) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Default collate_fn that merges ModelReadySample/HF dict outputs.
|
|
||||||
"""
|
|
||||||
pixel_values = []
|
|
||||||
prompts: List[Dict[str, Any]] = []
|
|
||||||
labels: List[Dict[str, Any]] = []
|
|
||||||
metadata: List[Dict[str, Any]] = []
|
|
||||||
for item in batch:
|
|
||||||
if isinstance(item, ModelReadySample):
|
|
||||||
payload = item.to_hf_dict()
|
|
||||||
else:
|
|
||||||
payload = item
|
|
||||||
pixel_values.append(payload["pixel_values"])
|
|
||||||
prompts.append(payload.get("prompts", {}))
|
|
||||||
labels.append(payload.get("labels", {}))
|
|
||||||
metadata.append(payload.get("metadata", {}))
|
|
||||||
stacked = {
|
|
||||||
"pixel_values": torch.as_tensor(np.stack(pixel_values)),
|
|
||||||
"prompts": prompts,
|
|
||||||
"labels": labels,
|
|
||||||
"metadata": metadata,
|
|
||||||
}
|
|
||||||
return stacked
|
|
||||||
@ -1,99 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from .base import BaseDataset, DatasetRecord
|
|
||||||
from .registry import DatasetRegistry
|
|
||||||
from .utils import (
|
|
||||||
extract_bboxes_from_mask,
|
|
||||||
sample_points_on_skeleton,
|
|
||||||
sample_points_per_component,
|
|
||||||
)
|
|
||||||
from ..model_configuration.config import DatasetConfig
|
|
||||||
|
|
||||||
|
|
||||||
@DatasetRegistry.register("crack500")
|
|
||||||
class Crack500Dataset(BaseDataset):
|
|
||||||
"""
|
|
||||||
Reference implementation that loads Crack500 samples from an image list.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
config: DatasetConfig,
|
|
||||||
expand_ratio: float = 0.05,
|
|
||||||
min_area: int = 10,
|
|
||||||
**kwargs,
|
|
||||||
) -> None:
|
|
||||||
extra = dict(config.extra_params or {})
|
|
||||||
expand_ratio = float(extra.get("expand_ratio", expand_ratio))
|
|
||||||
self.prompt_mode = extra.get("prompt_mode", "bbox")
|
|
||||||
self.num_points = int(extra.get("num_points", 5))
|
|
||||||
self.per_component = bool(extra.get("per_component", False))
|
|
||||||
self.expand_ratio = expand_ratio
|
|
||||||
self.min_area = min_area
|
|
||||||
super().__init__(config, **kwargs)
|
|
||||||
|
|
||||||
def load_records(self) -> List[DatasetRecord]:
|
|
||||||
base_dir = Path(self.config.data_root)
|
|
||||||
list_file = (
|
|
||||||
Path(self.config.annotation_file)
|
|
||||||
if self.config.annotation_file
|
|
||||||
else base_dir / (self.config.split_file or "test.txt")
|
|
||||||
)
|
|
||||||
if not list_file.exists():
|
|
||||||
raise FileNotFoundError(f"Missing Crack500 split file: {list_file}")
|
|
||||||
image_dir = base_dir / (self.config.image_folder or "testcrop")
|
|
||||||
mask_dir = base_dir / (self.config.mask_folder or "testdata")
|
|
||||||
records: List[DatasetRecord] = []
|
|
||||||
with list_file.open("r", encoding="utf-8") as handle:
|
|
||||||
for line in handle:
|
|
||||||
image_name = line.strip()
|
|
||||||
if not image_name:
|
|
||||||
continue
|
|
||||||
image_path = image_dir / image_name
|
|
||||||
mask_name = image_name.replace(".jpg", ".png")
|
|
||||||
mask_path = mask_dir / mask_name
|
|
||||||
metadata = {"split": self.config.split, "image_name": image_name}
|
|
||||||
records.append(
|
|
||||||
DatasetRecord(
|
|
||||||
image_path=image_path,
|
|
||||||
mask_path=mask_path if mask_path.exists() else None,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if not records:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"No records found in {image_dir} for split {self.config.split}"
|
|
||||||
)
|
|
||||||
return records
|
|
||||||
|
|
||||||
def build_prompts(
|
|
||||||
self,
|
|
||||||
record: DatasetRecord,
|
|
||||||
mask: Optional[np.ndarray],
|
|
||||||
) -> Dict[str, List[List[int]]]:
|
|
||||||
if mask is None:
|
|
||||||
return {}
|
|
||||||
if self.prompt_mode == "point":
|
|
||||||
points, point_labels = self._build_point_prompts(mask)
|
|
||||||
if points.size == 0:
|
|
||||||
return {}
|
|
||||||
prompts: Dict[str, List[List[int]]] = {"points": points.tolist()}
|
|
||||||
if point_labels.size > 0:
|
|
||||||
prompts["point_labels"] = point_labels.tolist()
|
|
||||||
return prompts
|
|
||||||
boxes = extract_bboxes_from_mask(
|
|
||||||
mask, expand_ratio=self.expand_ratio, min_area=self.min_area
|
|
||||||
)
|
|
||||||
return {"boxes": boxes}
|
|
||||||
|
|
||||||
def _build_point_prompts(self, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
if self.per_component:
|
|
||||||
return sample_points_per_component(mask, self.num_points)
|
|
||||||
points = sample_points_on_skeleton(mask, self.num_points)
|
|
||||||
labels = np.ones(points.shape[0], dtype=np.int32)
|
|
||||||
return points, labels
|
|
||||||
@ -1,33 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Dict, Type
|
|
||||||
|
|
||||||
from .base import BaseDataset
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetRegistry:
|
|
||||||
"""
|
|
||||||
Simple registry so configs can refer to datasets by string key.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_registry: Dict[str, Type[BaseDataset]] = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, name: str):
|
|
||||||
def decorator(dataset_cls: Type[BaseDataset]) -> Type[BaseDataset]:
|
|
||||||
cls._registry[name] = dataset_cls
|
|
||||||
dataset_cls.dataset_name = name
|
|
||||||
return dataset_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, name: str, *args, **kwargs) -> BaseDataset:
|
|
||||||
if name not in cls._registry:
|
|
||||||
raise KeyError(f"Dataset '{name}' is not registered.")
|
|
||||||
dataset_cls = cls._registry[name]
|
|
||||||
return dataset_cls(*args, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def available(cls) -> Dict[str, Type[BaseDataset]]:
|
|
||||||
return dict(cls._registry)
|
|
||||||
@ -1,91 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from skimage.morphology import skeletonize
|
|
||||||
|
|
||||||
|
|
||||||
def load_image_and_mask(image_path: str | Path, mask_path: str | Path) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""
|
|
||||||
Reads an RGB image and its mask counterpart.
|
|
||||||
"""
|
|
||||||
image_path = str(image_path)
|
|
||||||
mask_path = str(mask_path)
|
|
||||||
image = cv2.imread(image_path)
|
|
||||||
if image is None:
|
|
||||||
raise ValueError(f"无法加载图像: {image_path}")
|
|
||||||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
|
||||||
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
|
||||||
if mask is None:
|
|
||||||
raise ValueError(f"无法加载掩码: {mask_path}")
|
|
||||||
return image, mask
|
|
||||||
|
|
||||||
|
|
||||||
def extract_bboxes_from_mask(
|
|
||||||
mask: np.ndarray,
|
|
||||||
expand_ratio: float = 0.0,
|
|
||||||
min_area: int = 10,
|
|
||||||
) -> List[List[int]]:
|
|
||||||
"""
|
|
||||||
Extract bounding boxes from a binary mask using connected components.
|
|
||||||
"""
|
|
||||||
binary_mask = (mask > 0).astype(np.uint8)
|
|
||||||
num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)
|
|
||||||
bboxes: List[List[int]] = []
|
|
||||||
for i in range(1, num_labels):
|
|
||||||
x, y, w, h, area = stats[i]
|
|
||||||
if area < min_area:
|
|
||||||
continue
|
|
||||||
x1, y1 = x, y
|
|
||||||
x2, y2 = x + w, y + h
|
|
||||||
if expand_ratio > 0:
|
|
||||||
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2
|
|
||||||
w_new = w * (1 + expand_ratio)
|
|
||||||
h_new = h * (1 + expand_ratio)
|
|
||||||
x1 = max(0, int(cx - w_new / 2))
|
|
||||||
y1 = max(0, int(cy - h_new / 2))
|
|
||||||
x2 = min(mask.shape[1], int(cx + w_new / 2))
|
|
||||||
y2 = min(mask.shape[0], int(cy + h_new / 2))
|
|
||||||
bboxes.append([x1, y1, x2, y2])
|
|
||||||
return bboxes
|
|
||||||
|
|
||||||
|
|
||||||
def sample_points_on_skeleton(mask: np.ndarray, num_points: int) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Sample points uniformly along the mask skeleton in (x, y) order.
|
|
||||||
"""
|
|
||||||
binary_mask = (mask > 0).astype(bool)
|
|
||||||
try:
|
|
||||||
skeleton = skeletonize(binary_mask)
|
|
||||||
except Exception:
|
|
||||||
skeleton = binary_mask
|
|
||||||
coords = np.argwhere(skeleton)
|
|
||||||
if coords.size == 0:
|
|
||||||
return np.zeros((0, 2), dtype=np.int32)
|
|
||||||
if coords.shape[0] <= num_points:
|
|
||||||
points = coords[:, [1, 0]]
|
|
||||||
return points.astype(np.int32)
|
|
||||||
indices = np.linspace(0, coords.shape[0] - 1, num_points, dtype=int)
|
|
||||||
sampled = coords[indices][:, [1, 0]]
|
|
||||||
return sampled.astype(np.int32)
|
|
||||||
|
|
||||||
|
|
||||||
def sample_points_per_component(mask: np.ndarray, num_points_per_component: int) -> Tuple[np.ndarray, np.ndarray]:
|
|
||||||
"""
|
|
||||||
Sample points per connected component along each component's skeleton.
|
|
||||||
"""
|
|
||||||
num_labels, labels_map = cv2.connectedComponents((mask > 0).astype(np.uint8))
|
|
||||||
all_points = []
|
|
||||||
for region_id in range(1, num_labels):
|
|
||||||
region_mask = (labels_map == region_id).astype(np.uint8) * 255
|
|
||||||
points = sample_points_on_skeleton(region_mask, num_points_per_component)
|
|
||||||
if len(points):
|
|
||||||
all_points.append(points)
|
|
||||||
if not all_points:
|
|
||||||
return np.zeros((0, 2), dtype=np.int32), np.zeros(0, dtype=np.int32)
|
|
||||||
stacked = np.vstack(all_points)
|
|
||||||
labels = np.ones(stacked.shape[0], dtype=np.int32)
|
|
||||||
return stacked, labels
|
|
||||||
@ -1,14 +0,0 @@
|
|||||||
from .metrics import METRIC_REGISTRY, compute_dice, compute_iou, compute_precision, compute_recall
|
|
||||||
from .pipeline_eval import PipelineEvaluator
|
|
||||||
from .reporting import write_csv, write_json
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"METRIC_REGISTRY",
|
|
||||||
"PipelineEvaluator",
|
|
||||||
"compute_dice",
|
|
||||||
"compute_iou",
|
|
||||||
"compute_precision",
|
|
||||||
"compute_recall",
|
|
||||||
"write_csv",
|
|
||||||
"write_json",
|
|
||||||
]
|
|
||||||
@ -1,57 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Callable, Dict, Iterable, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def compute_iou(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
|
||||||
pred_bin = (pred >= threshold).astype(np.uint8)
|
|
||||||
target_bin = (target > 0).astype(np.uint8)
|
|
||||||
intersection = (pred_bin & target_bin).sum()
|
|
||||||
union = (pred_bin | target_bin).sum()
|
|
||||||
return float(intersection / union) if union else 0.0
|
|
||||||
|
|
||||||
|
|
||||||
def compute_dice(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
|
||||||
pred_bin = (pred >= threshold).astype(np.uint8)
|
|
||||||
target_bin = (target > 0).astype(np.uint8)
|
|
||||||
intersection = (pred_bin & target_bin).sum()
|
|
||||||
total = pred_bin.sum() + target_bin.sum()
|
|
||||||
return float((2 * intersection) / total) if total else 0.0
|
|
||||||
|
|
||||||
|
|
||||||
def compute_precision(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
|
||||||
pred_bin = (pred >= threshold).astype(np.uint8)
|
|
||||||
target_bin = (target > 0).astype(np.uint8)
|
|
||||||
tp = (pred_bin & target_bin).sum()
|
|
||||||
fp = (pred_bin & (1 - target_bin)).sum()
|
|
||||||
return float(tp / (tp + fp)) if (tp + fp) else 0.0
|
|
||||||
|
|
||||||
|
|
||||||
def compute_recall(pred: np.ndarray, target: np.ndarray, threshold: float = 0.5) -> float:
|
|
||||||
pred_bin = (pred >= threshold).astype(np.uint8)
|
|
||||||
target_bin = (target > 0).astype(np.uint8)
|
|
||||||
tp = (pred_bin & target_bin).sum()
|
|
||||||
fn = ((1 - pred_bin) & target_bin).sum()
|
|
||||||
return float(tp / (tp + fn)) if (tp + fn) else 0.0
|
|
||||||
|
|
||||||
|
|
||||||
MetricFn = Callable[[np.ndarray, np.ndarray, float], float]
|
|
||||||
|
|
||||||
|
|
||||||
METRIC_REGISTRY: Dict[str, MetricFn] = {
|
|
||||||
"iou": compute_iou,
|
|
||||||
"dice": compute_dice,
|
|
||||||
"precision": compute_precision,
|
|
||||||
"recall": compute_recall,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_metrics(metric_names: Iterable[str]) -> Dict[str, MetricFn]:
|
|
||||||
resolved: Dict[str, MetricFn] = {}
|
|
||||||
for name in metric_names:
|
|
||||||
if name not in METRIC_REGISTRY:
|
|
||||||
raise KeyError(f"Metric '{name}' is not registered.")
|
|
||||||
resolved[name] = METRIC_REGISTRY[name]
|
|
||||||
return resolved
|
|
||||||
@ -1,95 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from ..dataset import BaseDataset
|
|
||||||
from ..model import BaseModelAdapter
|
|
||||||
from ..model_configuration import EvaluationConfig
|
|
||||||
from .metrics import resolve_metrics
|
|
||||||
from .utils import extract_mask_from_pipeline_output
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineEvaluator:
|
|
||||||
"""
|
|
||||||
Runs a Hugging Face pipeline across a dataset and aggregates metrics.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dataset: BaseDataset,
|
|
||||||
adapter: BaseModelAdapter,
|
|
||||||
config: EvaluationConfig,
|
|
||||||
) -> None:
|
|
||||||
self.dataset = dataset
|
|
||||||
self.adapter = adapter
|
|
||||||
self.config = config
|
|
||||||
self.metrics = resolve_metrics(config.metrics)
|
|
||||||
|
|
||||||
def run(self) -> Dict[str, Any]:
|
|
||||||
pipe = self.adapter.build_pipeline()
|
|
||||||
aggregated: Dict[str, List[float]] = {name: [] for name in self.metrics}
|
|
||||||
output_dir = Path(self.config.output_dir)
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
requested = self.config.max_samples or len(self.dataset)
|
|
||||||
total = min(requested, len(self.dataset))
|
|
||||||
prog_bar = tqdm(range(total), total=total)
|
|
||||||
for idx in prog_bar:
|
|
||||||
sample = self.dataset[idx]
|
|
||||||
inputs = self._build_pipeline_inputs(sample)
|
|
||||||
preds = pipe(**inputs)
|
|
||||||
labels = sample.get("labels", {})
|
|
||||||
mask = labels.get("mask")
|
|
||||||
if mask is None:
|
|
||||||
continue
|
|
||||||
prediction_mask = self._extract_mask(preds)
|
|
||||||
for metric_name, metric_fn in self.metrics.items():
|
|
||||||
for threshold in self.config.thresholds:
|
|
||||||
value = metric_fn(prediction_mask, mask, threshold)
|
|
||||||
aggregated.setdefault(f"{metric_name}@{threshold}", []).append(value)
|
|
||||||
if self.config.save_predictions:
|
|
||||||
self._write_prediction(output_dir, idx, prediction_mask, sample["metadata"])
|
|
||||||
summary = {
|
|
||||||
"metrics": {k: float(np.mean(v)) if v else 0.0 for k, v in aggregated.items()},
|
|
||||||
"config": self.config.__dict__,
|
|
||||||
"num_samples": total,
|
|
||||||
}
|
|
||||||
with (output_dir / "evaluation_summary.json").open("w", encoding="utf-8") as handle:
|
|
||||||
json.dump(summary, handle, indent=2)
|
|
||||||
return summary
|
|
||||||
|
|
||||||
def _build_pipeline_inputs(self, sample: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
inputs: Dict[str, Any] = {"images": sample["pixel_values"]}
|
|
||||||
prompts = sample.get("prompts", {})
|
|
||||||
if "boxes" in prompts and prompts["boxes"]:
|
|
||||||
inputs["boxes"] = prompts["boxes"]
|
|
||||||
if "points" in prompts and prompts["points"]:
|
|
||||||
inputs["points"] = prompts["points"]
|
|
||||||
if "point_labels" in prompts and prompts["point_labels"]:
|
|
||||||
inputs["point_labels"] = prompts["point_labels"]
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
def _extract_mask(self, pipeline_output: Any) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Normalize pipeline outputs into numpy masks.
|
|
||||||
"""
|
|
||||||
return extract_mask_from_pipeline_output(pipeline_output)
|
|
||||||
|
|
||||||
def _write_prediction(
|
|
||||||
self,
|
|
||||||
output_dir: Path,
|
|
||||||
index: int,
|
|
||||||
mask: np.ndarray,
|
|
||||||
metadata: Optional[Dict[str, Any]],
|
|
||||||
) -> None:
|
|
||||||
if metadata and "image_name" in metadata:
|
|
||||||
filename = metadata["image_name"].replace(".jpg", "_pred.npy")
|
|
||||||
else:
|
|
||||||
filename = f"sample_{index:04d}_pred.npy"
|
|
||||||
target_path = output_dir / "predictions"
|
|
||||||
target_path.mkdir(parents=True, exist_ok=True)
|
|
||||||
np.save(target_path / filename, mask)
|
|
||||||
@ -1,25 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import csv
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Iterable
|
|
||||||
|
|
||||||
|
|
||||||
def write_json(summary: Dict, output_path: Path) -> None:
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
with output_path.open("w", encoding="utf-8") as handle:
|
|
||||||
json.dump(summary, handle, indent=2)
|
|
||||||
|
|
||||||
|
|
||||||
def write_csv(rows: Iterable[Dict], output_path: Path) -> None:
|
|
||||||
rows = list(rows)
|
|
||||||
if not rows:
|
|
||||||
return
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
fieldnames = sorted(rows[0].keys())
|
|
||||||
with output_path.open("w", encoding="utf-8", newline="") as handle:
|
|
||||||
writer = csv.DictWriter(handle, fieldnames=fieldnames)
|
|
||||||
writer.writeheader()
|
|
||||||
for row in rows:
|
|
||||||
writer.writerow(row)
|
|
||||||
@ -1,55 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, replace
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import HfArgumentParser
|
|
||||||
|
|
||||||
from ..dataset import DatasetRegistry
|
|
||||||
from ..model import ModelRegistry
|
|
||||||
from ..model_configuration import ConfigRegistry, EvaluationConfig
|
|
||||||
from .pipeline_eval import PipelineEvaluator
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PipelineCLIArguments:
|
|
||||||
config_name: str = "sam2_bbox_prompt"
|
|
||||||
model_key: str = "sam2"
|
|
||||||
split: str = "test"
|
|
||||||
split_file: Optional[str] = None
|
|
||||||
device: Optional[str] = None
|
|
||||||
max_samples: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = HfArgumentParser(PipelineCLIArguments)
|
|
||||||
(cli_args,) = parser.parse_args_into_dataclasses()
|
|
||||||
project_config = ConfigRegistry.get(cli_args.config_name)
|
|
||||||
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
|
|
||||||
dataset = DatasetRegistry.create(
|
|
||||||
dataset_cfg.name,
|
|
||||||
config=dataset_cfg,
|
|
||||||
return_hf_dict=True,
|
|
||||||
)
|
|
||||||
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
|
|
||||||
evaluation_config = replace(
|
|
||||||
project_config.evaluation,
|
|
||||||
max_samples=cli_args.max_samples,
|
|
||||||
)
|
|
||||||
if cli_args.device:
|
|
||||||
adapter.build_pipeline(device=cli_args.device)
|
|
||||||
evaluator = PipelineEvaluator(
|
|
||||||
dataset=dataset,
|
|
||||||
adapter=adapter,
|
|
||||||
config=evaluation_config,
|
|
||||||
)
|
|
||||||
summary = evaluator.run()
|
|
||||||
LOGGER.info("Evaluation summary: %s", summary)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
main()
|
|
||||||
@ -1,16 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
|
|
||||||
def extract_mask_from_pipeline_output(pipeline_output: Any) -> np.ndarray:
|
|
||||||
if isinstance(pipeline_output, list):
|
|
||||||
pipeline_output = pipeline_output[0]
|
|
||||||
mask = pipeline_output.get("mask")
|
|
||||||
if mask is None:
|
|
||||||
raise ValueError("Pipeline output missing 'mask'.")
|
|
||||||
if isinstance(mask, np.ndarray):
|
|
||||||
return mask
|
|
||||||
return np.array(mask)
|
|
||||||
@ -1,7 +0,0 @@
|
|||||||
"""
|
|
||||||
Backward-compatible wrapper that re-exports the predictor relocated to src.model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .model.predictor import HFSam2Predictor, build_hf_sam2_predictor
|
|
||||||
|
|
||||||
__all__ = ["HFSam2Predictor", "build_hf_sam2_predictor"]
|
|
||||||
@ -1,17 +0,0 @@
|
|||||||
from .base import BaseModelAdapter
|
|
||||||
from .inference import predict_with_bbox_prompt
|
|
||||||
from .predictor import HFSam2Predictor, build_hf_sam2_predictor
|
|
||||||
from .registry import ModelRegistry
|
|
||||||
from .sam2_adapter import Sam2ModelAdapter
|
|
||||||
from .trainer import FineTuningTrainer, TrainerArtifacts
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"BaseModelAdapter",
|
|
||||||
"FineTuningTrainer",
|
|
||||||
"HFSam2Predictor",
|
|
||||||
"ModelRegistry",
|
|
||||||
"Sam2ModelAdapter",
|
|
||||||
"TrainerArtifacts",
|
|
||||||
"build_hf_sam2_predictor",
|
|
||||||
"predict_with_bbox_prompt",
|
|
||||||
]
|
|
||||||
@ -1,66 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import abc
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from transformers import pipeline
|
|
||||||
|
|
||||||
from ..model_configuration import ModelConfig
|
|
||||||
|
|
||||||
|
|
||||||
class BaseModelAdapter(abc.ABC):
|
|
||||||
"""
|
|
||||||
Thin wrapper that standardizes how we instantiate models/processors/pipelines.
|
|
||||||
"""
|
|
||||||
|
|
||||||
task: str = "image-segmentation"
|
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig) -> None:
|
|
||||||
self.config = config
|
|
||||||
self._model = None
|
|
||||||
self._processor = None
|
|
||||||
self._pipeline = None
|
|
||||||
|
|
||||||
def load_pretrained(self):
|
|
||||||
if self._model is None or self._processor is None:
|
|
||||||
self._model, self._processor = self._load_pretrained()
|
|
||||||
return self._model, self._processor
|
|
||||||
|
|
||||||
def build_pipeline(
|
|
||||||
self,
|
|
||||||
device: Optional[str] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
if self._pipeline is None:
|
|
||||||
model, processor = self.load_pretrained()
|
|
||||||
pipe_kwargs = {
|
|
||||||
"task": self.task,
|
|
||||||
"model": model,
|
|
||||||
"image_processor": processor,
|
|
||||||
**self.config.pipeline_kwargs,
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
if device is not None:
|
|
||||||
pipe_kwargs["device"] = device
|
|
||||||
self._pipeline = self._create_pipeline(pipe_kwargs)
|
|
||||||
return self._pipeline
|
|
||||||
|
|
||||||
async def build_pipeline_async(self, **kwargs):
|
|
||||||
"""
|
|
||||||
Async helper for future multi-device orchestration.
|
|
||||||
"""
|
|
||||||
return self.build_pipeline(**kwargs)
|
|
||||||
|
|
||||||
def save_pretrained(self, output_dir: str) -> None:
|
|
||||||
model, processor = self.load_pretrained()
|
|
||||||
model.save_pretrained(output_dir)
|
|
||||||
processor.save_pretrained(output_dir)
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def _load_pretrained(self):
|
|
||||||
"""
|
|
||||||
Return (model, processor) tuple.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _create_pipeline(self, pipe_kwargs: Dict[str, Any]):
|
|
||||||
return pipeline(**pipe_kwargs)
|
|
||||||
@ -1,32 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from .predictor import HFSam2Predictor
|
|
||||||
|
|
||||||
|
|
||||||
def predict_with_bbox_prompt(
|
|
||||||
predictor: HFSam2Predictor,
|
|
||||||
image: np.ndarray,
|
|
||||||
bboxes: List[np.ndarray],
|
|
||||||
) -> np.ndarray:
|
|
||||||
"""
|
|
||||||
Run SAM2 predictions for each bounding box and merge the masks.
|
|
||||||
"""
|
|
||||||
predictor.set_image(image)
|
|
||||||
if not bboxes:
|
|
||||||
return np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
|
||||||
combined_mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
|
|
||||||
for bbox in bboxes:
|
|
||||||
masks, _, _ = predictor.predict(
|
|
||||||
point_coords=None,
|
|
||||||
point_labels=None,
|
|
||||||
box=bbox,
|
|
||||||
multimask_output=False,
|
|
||||||
)
|
|
||||||
mask = masks[0]
|
|
||||||
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
|
|
||||||
combined_mask = combined_mask * 255
|
|
||||||
return combined_mask
|
|
||||||
@ -1,158 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
from PIL import Image
|
|
||||||
from transformers import SamModel, SamProcessor
|
|
||||||
|
|
||||||
|
|
||||||
class HFSam2Predictor:
|
|
||||||
"""
|
|
||||||
Predictor wrapper around Hugging Face SAM2 models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str = "facebook/sam2-hiera-small",
|
|
||||||
device: Optional[str] = None,
|
|
||||||
dtype: torch.dtype = torch.bfloat16,
|
|
||||||
) -> None:
|
|
||||||
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
self.dtype = dtype
|
|
||||||
self.model = SamModel.from_pretrained(model_id).to(self.device)
|
|
||||||
self.processor = SamProcessor.from_pretrained("./configs/preprocesser.json")
|
|
||||||
self._override_processor_config()
|
|
||||||
if dtype == torch.bfloat16:
|
|
||||||
self.model = self.model.to(dtype=dtype)
|
|
||||||
self.model.eval()
|
|
||||||
self.current_image = None
|
|
||||||
self.current_image_embeddings = None
|
|
||||||
|
|
||||||
def set_image(self, image: np.ndarray) -> None:
|
|
||||||
if isinstance(image, np.ndarray):
|
|
||||||
pil_image = Image.fromarray(image.astype(np.uint8))
|
|
||||||
else:
|
|
||||||
pil_image = image
|
|
||||||
self.current_image = pil_image
|
|
||||||
with torch.inference_mode():
|
|
||||||
inputs = self.processor(images=pil_image, return_tensors="pt").to(self.device)
|
|
||||||
if self.dtype == torch.bfloat16:
|
|
||||||
inputs = {
|
|
||||||
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
|
|
||||||
for k, v in inputs.items()
|
|
||||||
}
|
|
||||||
self.current_image_embeddings = self.model.get_image_embeddings(inputs["pixel_values"])
|
|
||||||
|
|
||||||
def predict(
|
|
||||||
self,
|
|
||||||
point_coords: Optional[np.ndarray] = None,
|
|
||||||
point_labels: Optional[np.ndarray] = None,
|
|
||||||
box: Optional[np.ndarray] = None,
|
|
||||||
multimask_output: bool = False,
|
|
||||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
if self.current_image is None:
|
|
||||||
raise ValueError("No image set. Call set_image() first.")
|
|
||||||
input_points = self._prepare_points(point_coords)
|
|
||||||
input_labels = self._prepare_labels(point_labels)
|
|
||||||
input_boxes = self._prepare_boxes(box)
|
|
||||||
with torch.inference_mode():
|
|
||||||
inputs = self.processor(
|
|
||||||
images=self.current_image,
|
|
||||||
input_points=input_points,
|
|
||||||
input_labels=input_labels,
|
|
||||||
input_boxes=input_boxes,
|
|
||||||
return_tensors="pt",
|
|
||||||
).to(self.device)
|
|
||||||
if self.dtype == torch.bfloat16:
|
|
||||||
inputs = {
|
|
||||||
k: v.to(dtype=self.dtype) if v.dtype == torch.float32 else v
|
|
||||||
for k, v in inputs.items()
|
|
||||||
}
|
|
||||||
inputs.pop("pixel_values", None)
|
|
||||||
inputs["image_embeddings"] = self.current_image_embeddings
|
|
||||||
outputs = self.model(**inputs, multimask_output=multimask_output)
|
|
||||||
masks = self.processor.image_processor.post_process_masks(
|
|
||||||
outputs.pred_masks.float().cpu(),
|
|
||||||
inputs["original_sizes"].cpu(),
|
|
||||||
inputs["reshaped_input_sizes"].cpu(),
|
|
||||||
)[0]
|
|
||||||
scores = outputs.iou_scores.float().cpu().numpy()[0]
|
|
||||||
masks_np = (masks.squeeze(1).numpy() > 0).astype(np.uint8)
|
|
||||||
logits = outputs.pred_masks.float().cpu().numpy()[0]
|
|
||||||
return masks_np, scores, logits
|
|
||||||
|
|
||||||
def _prepare_points(self, coords: Optional[np.ndarray]):
|
|
||||||
"""
|
|
||||||
Points must be shaped (num_points, 2); wrap in outer batch dimension.
|
|
||||||
"""
|
|
||||||
if coords is None:
|
|
||||||
return None
|
|
||||||
coords_arr = np.asarray(coords)
|
|
||||||
if coords_arr.ndim == 1:
|
|
||||||
coords_arr = coords_arr[None, :]
|
|
||||||
if coords_arr.ndim != 2:
|
|
||||||
raise ValueError(f"Point coords must be 2-D, got {coords_arr.shape}.")
|
|
||||||
return [coords_arr.tolist()]
|
|
||||||
|
|
||||||
def _prepare_labels(self, labels: Optional[np.ndarray]):
|
|
||||||
"""
|
|
||||||
Labels mirror the point dimension and are shaped (num_points,).
|
|
||||||
"""
|
|
||||||
if labels is None:
|
|
||||||
return None
|
|
||||||
labels_arr = np.asarray(labels)
|
|
||||||
if labels_arr.ndim == 0:
|
|
||||||
labels_arr = labels_arr[None]
|
|
||||||
if labels_arr.ndim != 1:
|
|
||||||
raise ValueError(f"Point labels must be 1-D, got {labels_arr.shape}.")
|
|
||||||
return [labels_arr.tolist()]
|
|
||||||
|
|
||||||
def _prepare_boxes(self, boxes: Optional[np.ndarray]):
|
|
||||||
"""
|
|
||||||
HF expects boxes in shape (batch, num_boxes, 4); accept (4,), (N,4), or (B,N,4).
|
|
||||||
"""
|
|
||||||
if boxes is None:
|
|
||||||
return None
|
|
||||||
boxes_arr = np.asarray(boxes)
|
|
||||||
if boxes_arr.ndim == 1:
|
|
||||||
return [[boxes_arr.tolist()]]
|
|
||||||
if boxes_arr.ndim == 2:
|
|
||||||
return [boxes_arr.tolist()]
|
|
||||||
if boxes_arr.ndim == 3:
|
|
||||||
return boxes_arr.tolist()
|
|
||||||
raise ValueError(f"Boxes should be 1/2/3-D, got {boxes_arr.shape}.")
|
|
||||||
|
|
||||||
def _override_processor_config(self) -> None:
|
|
||||||
"""
|
|
||||||
Override processor config with local settings to avoid upstream regressions.
|
|
||||||
"""
|
|
||||||
config_path = Path(__file__).resolve().parents[2] / "configs" / "preprocesser.json"
|
|
||||||
if not config_path.exists():
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
config_dict = json.loads(config_path.read_text())
|
|
||||||
except Exception:
|
|
||||||
return
|
|
||||||
image_processor = getattr(self.processor, "image_processor", None)
|
|
||||||
if image_processor is None or not hasattr(image_processor, "config"):
|
|
||||||
return
|
|
||||||
# config behaves like a dict; update in-place.
|
|
||||||
try:
|
|
||||||
image_processor.config.update(config_dict)
|
|
||||||
except Exception:
|
|
||||||
for key, value in config_dict.items():
|
|
||||||
try:
|
|
||||||
setattr(image_processor.config, key, value)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
|
|
||||||
|
|
||||||
def build_hf_sam2_predictor(
|
|
||||||
model_id: str = "facebook/sam2-hiera-small",
|
|
||||||
device: Optional[str] = None,
|
|
||||||
) -> HFSam2Predictor:
|
|
||||||
return HFSam2Predictor(model_id=model_id, device=device)
|
|
||||||
@ -1,33 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Dict, Type
|
|
||||||
|
|
||||||
from ..model_configuration import ModelConfig
|
|
||||||
from .base import BaseModelAdapter
|
|
||||||
|
|
||||||
|
|
||||||
class ModelRegistry:
|
|
||||||
"""
|
|
||||||
Maps model keys to adapter classes so configs can reference them declaratively.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_registry: Dict[str, Type[BaseModelAdapter]] = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, name: str):
|
|
||||||
def decorator(adapter_cls: Type[BaseModelAdapter]) -> Type[BaseModelAdapter]:
|
|
||||||
cls._registry[name] = adapter_cls
|
|
||||||
return adapter_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, name: str, config: ModelConfig) -> BaseModelAdapter:
|
|
||||||
if name not in cls._registry:
|
|
||||||
raise KeyError(f"ModelAdapter '{name}' is not registered.")
|
|
||||||
adapter_cls = cls._registry[name]
|
|
||||||
return adapter_cls(config)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def available(cls) -> Dict[str, Type[BaseModelAdapter]]:
|
|
||||||
return dict(cls._registry)
|
|
||||||
@ -1,35 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Tuple
|
|
||||||
|
|
||||||
from transformers import AutoModelForImageSegmentation, AutoProcessor
|
|
||||||
|
|
||||||
from ..model_configuration import ModelConfig
|
|
||||||
from .base import BaseModelAdapter
|
|
||||||
from .registry import ModelRegistry
|
|
||||||
|
|
||||||
|
|
||||||
@ModelRegistry.register("sam2")
|
|
||||||
class Sam2ModelAdapter(BaseModelAdapter):
|
|
||||||
"""
|
|
||||||
Adapter that exposes SAM2 checkpoints through the HF pipeline interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: ModelConfig) -> None:
|
|
||||||
super().__init__(config)
|
|
||||||
self.task = "image-segmentation"
|
|
||||||
|
|
||||||
def _load_pretrained(self) -> Tuple[Any, Any]:
|
|
||||||
model = AutoModelForImageSegmentation.from_pretrained(
|
|
||||||
self.config.name_or_path,
|
|
||||||
revision=self.config.revision,
|
|
||||||
cache_dir=self.config.cache_dir,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
processor = AutoProcessor.from_pretrained(
|
|
||||||
self.config.name_or_path,
|
|
||||||
revision=self.config.revision,
|
|
||||||
cache_dir=self.config.cache_dir,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
return model, processor
|
|
||||||
@ -1,88 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, replace
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import HfArgumentParser
|
|
||||||
|
|
||||||
from ..dataset import DatasetRegistry
|
|
||||||
from ..model_configuration import ConfigRegistry, DatasetConfig
|
|
||||||
from .registry import ModelRegistry
|
|
||||||
from .trainer import FineTuningTrainer
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainCLIArguments:
|
|
||||||
config_name: str = "sam2_bbox_prompt"
|
|
||||||
model_key: str = "sam2"
|
|
||||||
train_split: str = "train"
|
|
||||||
eval_split: str = "val"
|
|
||||||
train_split_file: Optional[str] = None
|
|
||||||
eval_split_file: Optional[str] = None
|
|
||||||
skip_eval: bool = False
|
|
||||||
device: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def build_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
|
|
||||||
overrides = {}
|
|
||||||
if split:
|
|
||||||
overrides["split"] = split
|
|
||||||
if split_file:
|
|
||||||
overrides["split_file"] = split_file
|
|
||||||
return replace(config, **overrides)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = HfArgumentParser(TrainCLIArguments)
|
|
||||||
(cli_args,) = parser.parse_args_into_dataclasses()
|
|
||||||
project_config = ConfigRegistry.get(cli_args.config_name)
|
|
||||||
train_dataset_cfg = build_dataset(
|
|
||||||
project_config.dataset, cli_args.train_split, cli_args.train_split_file
|
|
||||||
)
|
|
||||||
eval_dataset_cfg = (
|
|
||||||
build_dataset(project_config.dataset, cli_args.eval_split, cli_args.eval_split_file)
|
|
||||||
if not cli_args.skip_eval
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
train_dataset = DatasetRegistry.create(
|
|
||||||
train_dataset_cfg.name,
|
|
||||||
config=train_dataset_cfg,
|
|
||||||
return_hf_dict=True,
|
|
||||||
)
|
|
||||||
eval_dataset = (
|
|
||||||
DatasetRegistry.create(
|
|
||||||
eval_dataset_cfg.name,
|
|
||||||
config=eval_dataset_cfg,
|
|
||||||
return_hf_dict=True,
|
|
||||||
)
|
|
||||||
if eval_dataset_cfg
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
|
|
||||||
if cli_args.device:
|
|
||||||
adapter.build_pipeline(device=cli_args.device)
|
|
||||||
|
|
||||||
trainer_builder = FineTuningTrainer(
|
|
||||||
adapter=adapter,
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
training_config=project_config.training,
|
|
||||||
)
|
|
||||||
artifacts = trainer_builder.build()
|
|
||||||
LOGGER.info("Starting training with args: %s", artifacts.training_args)
|
|
||||||
train_result = artifacts.trainer.train()
|
|
||||||
LOGGER.info("Training finished: %s", train_result)
|
|
||||||
artifacts.trainer.save_model(project_config.training.output_dir)
|
|
||||||
if eval_dataset and not cli_args.skip_eval:
|
|
||||||
metrics = artifacts.trainer.evaluate()
|
|
||||||
LOGGER.info("Evaluation metrics: %s", metrics)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
main()
|
|
||||||
@ -1,64 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from transformers import Trainer, TrainingArguments
|
|
||||||
|
|
||||||
from ..dataset import BaseDataset, collate_samples
|
|
||||||
from ..model_configuration import TrainingConfig
|
|
||||||
from .base import BaseModelAdapter
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainerArtifacts:
|
|
||||||
trainer: Trainer
|
|
||||||
training_args: TrainingArguments
|
|
||||||
|
|
||||||
|
|
||||||
class FineTuningTrainer:
|
|
||||||
"""
|
|
||||||
Helper that bridges TrainingConfig + datasets + adapters into HF Trainer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
adapter: BaseModelAdapter,
|
|
||||||
train_dataset: Optional[BaseDataset],
|
|
||||||
eval_dataset: Optional[BaseDataset],
|
|
||||||
training_config: TrainingConfig,
|
|
||||||
trainer_kwargs: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> None:
|
|
||||||
self.adapter = adapter
|
|
||||||
self.train_dataset = train_dataset
|
|
||||||
self.eval_dataset = eval_dataset
|
|
||||||
self.training_config = training_config
|
|
||||||
self.trainer_kwargs = trainer_kwargs or {}
|
|
||||||
|
|
||||||
def build(self) -> TrainerArtifacts:
|
|
||||||
model, processor = self.adapter.load_pretrained()
|
|
||||||
training_args = TrainingArguments(
|
|
||||||
output_dir=self.training_config.output_dir,
|
|
||||||
num_train_epochs=self.training_config.num_train_epochs,
|
|
||||||
per_device_train_batch_size=self.training_config.per_device_train_batch_size,
|
|
||||||
per_device_eval_batch_size=self.training_config.per_device_eval_batch_size,
|
|
||||||
learning_rate=self.training_config.learning_rate,
|
|
||||||
gradient_accumulation_steps=self.training_config.gradient_accumulation_steps,
|
|
||||||
lr_scheduler_type=self.training_config.lr_scheduler_type,
|
|
||||||
warmup_ratio=self.training_config.warmup_ratio,
|
|
||||||
weight_decay=self.training_config.weight_decay,
|
|
||||||
seed=self.training_config.seed,
|
|
||||||
fp16=self.training_config.fp16,
|
|
||||||
bf16=self.training_config.bf16,
|
|
||||||
report_to=self.training_config.report_to,
|
|
||||||
)
|
|
||||||
hf_trainer = Trainer(
|
|
||||||
model=model,
|
|
||||||
args=training_args,
|
|
||||||
train_dataset=self.train_dataset,
|
|
||||||
eval_dataset=self.eval_dataset,
|
|
||||||
data_collator=collate_samples,
|
|
||||||
tokenizer=processor,
|
|
||||||
**self.trainer_kwargs,
|
|
||||||
)
|
|
||||||
return TrainerArtifacts(trainer=hf_trainer, training_args=training_args)
|
|
||||||
@ -1,22 +0,0 @@
|
|||||||
from .config import (
|
|
||||||
DatasetConfig,
|
|
||||||
EvaluationConfig,
|
|
||||||
ModelConfig,
|
|
||||||
ProjectConfig,
|
|
||||||
TrainingConfig,
|
|
||||||
VisualizationConfig,
|
|
||||||
)
|
|
||||||
from .registry import ConfigRegistry
|
|
||||||
|
|
||||||
# ensure example configs register themselves
|
|
||||||
from . import sam2_bbox # noqa: F401
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"DatasetConfig",
|
|
||||||
"EvaluationConfig",
|
|
||||||
"ModelConfig",
|
|
||||||
"ProjectConfig",
|
|
||||||
"TrainingConfig",
|
|
||||||
"VisualizationConfig",
|
|
||||||
"ConfigRegistry",
|
|
||||||
]
|
|
||||||
@ -1,89 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List, Optional
|
|
||||||
|
|
||||||
|
|
||||||
def _default_dict() -> Dict[str, Any]:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DatasetConfig:
|
|
||||||
name: str
|
|
||||||
data_root: str
|
|
||||||
split: str = "test"
|
|
||||||
split_file: Optional[str] = None
|
|
||||||
annotation_file: Optional[str] = None
|
|
||||||
image_folder: Optional[str] = None
|
|
||||||
mask_folder: Optional[str] = None
|
|
||||||
extra_params: Dict[str, Any] = field(default_factory=_default_dict)
|
|
||||||
|
|
||||||
def resolve_path(self, relative: Optional[str]) -> Optional[Path]:
|
|
||||||
if relative is None:
|
|
||||||
return None
|
|
||||||
return Path(self.data_root) / relative
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelConfig:
|
|
||||||
name_or_path: str
|
|
||||||
revision: Optional[str] = None
|
|
||||||
config_name: Optional[str] = None
|
|
||||||
cache_dir: Optional[str] = None
|
|
||||||
prompt_type: str = "bbox"
|
|
||||||
image_size: Optional[int] = None
|
|
||||||
pipeline_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
|
|
||||||
adapter_kwargs: Dict[str, Any] = field(default_factory=_default_dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TrainingConfig:
|
|
||||||
output_dir: str = "./outputs"
|
|
||||||
num_train_epochs: float = 3.0
|
|
||||||
per_device_train_batch_size: int = 1
|
|
||||||
per_device_eval_batch_size: int = 1
|
|
||||||
learning_rate: float = 1e-4
|
|
||||||
weight_decay: float = 0.0
|
|
||||||
gradient_accumulation_steps: int = 1
|
|
||||||
lr_scheduler_type: str = "linear"
|
|
||||||
warmup_ratio: float = 0.0
|
|
||||||
seed: int = 42
|
|
||||||
fp16: bool = False
|
|
||||||
bf16: bool = False
|
|
||||||
report_to: List[str] = field(default_factory=lambda: ["tensorboard"])
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EvaluationConfig:
|
|
||||||
output_dir: str = "./results"
|
|
||||||
metrics: List[str] = field(default_factory=lambda: ["iou", "dice", "precision", "recall"])
|
|
||||||
thresholds: List[float] = field(default_factory=lambda: [0.5])
|
|
||||||
max_samples: Optional[int] = None
|
|
||||||
save_predictions: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class VisualizationConfig:
|
|
||||||
num_samples: int = 20
|
|
||||||
overlay_alpha: float = 0.6
|
|
||||||
save_dir: str = "./results/visualizations"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProjectConfig:
|
|
||||||
dataset: DatasetConfig
|
|
||||||
model: ModelConfig
|
|
||||||
training: TrainingConfig = field(default_factory=TrainingConfig)
|
|
||||||
evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
|
|
||||||
visualization: VisualizationConfig = field(default_factory=VisualizationConfig)
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"dataset": self.dataset,
|
|
||||||
"model": self.model,
|
|
||||||
"training": self.training,
|
|
||||||
"evaluation": self.evaluation,
|
|
||||||
"visualization": self.visualization,
|
|
||||||
}
|
|
||||||
@ -1,28 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from .config import ProjectConfig
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigRegistry:
|
|
||||||
"""
|
|
||||||
Stores reusable project configurations (dataset + model + training bundle).
|
|
||||||
"""
|
|
||||||
|
|
||||||
_registry: Dict[str, ProjectConfig] = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, name: str, config: ProjectConfig) -> ProjectConfig:
|
|
||||||
cls._registry[name] = config
|
|
||||||
return config
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls, name: str) -> ProjectConfig:
|
|
||||||
if name not in cls._registry:
|
|
||||||
raise KeyError(f"ProjectConfig '{name}' is not registered.")
|
|
||||||
return cls._registry[name]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def available(cls) -> Dict[str, ProjectConfig]:
|
|
||||||
return dict(cls._registry)
|
|
||||||
@ -1,47 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .config import (
|
|
||||||
DatasetConfig,
|
|
||||||
EvaluationConfig,
|
|
||||||
ModelConfig,
|
|
||||||
ProjectConfig,
|
|
||||||
TrainingConfig,
|
|
||||||
VisualizationConfig,
|
|
||||||
)
|
|
||||||
from .registry import ConfigRegistry
|
|
||||||
|
|
||||||
|
|
||||||
SAM2_BBOX_CONFIG = ProjectConfig(
|
|
||||||
dataset=DatasetConfig(
|
|
||||||
name="crack500",
|
|
||||||
data_root="./crack500",
|
|
||||||
split="test",
|
|
||||||
split_file="test.txt",
|
|
||||||
image_folder="testcrop",
|
|
||||||
mask_folder="testdata",
|
|
||||||
),
|
|
||||||
model=ModelConfig(
|
|
||||||
name_or_path="facebook/sam2.1-hiera-small",
|
|
||||||
prompt_type="bbox",
|
|
||||||
pipeline_kwargs={"batch_size": 1},
|
|
||||||
),
|
|
||||||
training=TrainingConfig(
|
|
||||||
output_dir="./outputs/sam2_bbox",
|
|
||||||
num_train_epochs=5,
|
|
||||||
per_device_train_batch_size=1,
|
|
||||||
per_device_eval_batch_size=1,
|
|
||||||
learning_rate=1e-4,
|
|
||||||
gradient_accumulation_steps=4,
|
|
||||||
lr_scheduler_type="cosine",
|
|
||||||
),
|
|
||||||
evaluation=EvaluationConfig(
|
|
||||||
output_dir="./results/bbox_prompt",
|
|
||||||
thresholds=[0.3, 0.5, 0.75],
|
|
||||||
),
|
|
||||||
visualization=VisualizationConfig(
|
|
||||||
save_dir="./results/bbox_prompt/visualizations",
|
|
||||||
num_samples=20,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
ConfigRegistry.register("sam2_bbox_prompt", SAM2_BBOX_CONFIG)
|
|
||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
点提示方式的 SAM2 裂缝分割实现(使用 HuggingFace Transformers)
|
点提示方式的 SAM2 裂缝分割实现
|
||||||
使用骨架采样策略,支持 1, 3, 5 个点
|
使用骨架采样策略,支持 1, 3, 5 个点
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -13,7 +13,8 @@ from tqdm import tqdm
|
|||||||
import json
|
import json
|
||||||
from skimage.morphology import skeletonize
|
from skimage.morphology import skeletonize
|
||||||
|
|
||||||
from .hf_sam2_predictor import HFSam2Predictor
|
from sam2.build_sam import build_sam2
|
||||||
|
from sam2.sam2_image_predictor import SAM2ImagePredictor
|
||||||
|
|
||||||
|
|
||||||
def sample_points_on_skeleton(mask: np.ndarray, num_points: int = 5) -> np.ndarray:
|
def sample_points_on_skeleton(mask: np.ndarray, num_points: int = 5) -> np.ndarray:
|
||||||
@ -128,7 +129,7 @@ def load_image_and_mask(image_path: str, mask_path: str) -> Tuple[np.ndarray, np
|
|||||||
|
|
||||||
|
|
||||||
def predict_with_point_prompt(
|
def predict_with_point_prompt(
|
||||||
predictor: HFSam2Predictor,
|
predictor: SAM2ImagePredictor,
|
||||||
image: np.ndarray,
|
image: np.ndarray,
|
||||||
points: np.ndarray,
|
points: np.ndarray,
|
||||||
point_labels: np.ndarray = None
|
point_labels: np.ndarray = None
|
||||||
@ -137,7 +138,7 @@ def predict_with_point_prompt(
|
|||||||
使用点提示进行 SAM2 预测
|
使用点提示进行 SAM2 预测
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
predictor: HFSam2Predictor 实例
|
predictor: SAM2ImagePredictor 实例
|
||||||
image: RGB 图像 (H, W, 3)
|
image: RGB 图像 (H, W, 3)
|
||||||
points: 点坐标 (N, 2),格式为 [x, y]
|
points: 点坐标 (N, 2),格式为 [x, y]
|
||||||
point_labels: 点标签 (N,),1 表示正样本,0 表示负样本
|
point_labels: 点标签 (N,),1 表示正样本,0 表示负样本
|
||||||
@ -175,7 +176,7 @@ def predict_with_point_prompt(
|
|||||||
def process_test_set(
|
def process_test_set(
|
||||||
data_root: str,
|
data_root: str,
|
||||||
test_file: str,
|
test_file: str,
|
||||||
predictor: HFSam2Predictor,
|
predictor: SAM2ImagePredictor,
|
||||||
output_dir: str,
|
output_dir: str,
|
||||||
num_points: int = 5,
|
num_points: int = 5,
|
||||||
per_component: bool = False
|
per_component: bool = False
|
||||||
@ -186,7 +187,7 @@ def process_test_set(
|
|||||||
Args:
|
Args:
|
||||||
data_root: 数据集根目录
|
data_root: 数据集根目录
|
||||||
test_file: 测试集文件路径 (test.txt)
|
test_file: 测试集文件路径 (test.txt)
|
||||||
predictor: HFSam2Predictor 实例
|
predictor: SAM2ImagePredictor 实例
|
||||||
output_dir: 输出目录
|
output_dir: 输出目录
|
||||||
num_points: 采样点数量
|
num_points: 采样点数量
|
||||||
per_component: 是否为每个连通域独立采样
|
per_component: 是否为每个连通域独立采样
|
||||||
@ -241,7 +242,7 @@ def process_test_set(
|
|||||||
point_labels = np.ones(len(points), dtype=np.int32)
|
point_labels = np.ones(len(points), dtype=np.int32)
|
||||||
|
|
||||||
# 使用 SAM2 预测
|
# 使用 SAM2 预测
|
||||||
with torch.inference_mode():
|
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
|
||||||
mask_pred = predict_with_point_prompt(
|
mask_pred = predict_with_point_prompt(
|
||||||
predictor, image, points, point_labels
|
predictor, image, points, point_labels
|
||||||
)
|
)
|
||||||
@ -280,22 +281,24 @@ def main():
|
|||||||
"""主函数"""
|
"""主函数"""
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="SAM2 点提示方式 (HuggingFace) - Crack500 数据集评估")
|
parser = argparse.ArgumentParser(description="SAM2 点提示方式 - Crack500 数据集评估")
|
||||||
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
parser.add_argument("--data_root", type=str, default="./crack500", help="数据集根目录")
|
||||||
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件")
|
parser.add_argument("--test_file", type=str, default="./crack500/test.txt", help="测试集文件")
|
||||||
parser.add_argument("--model_id", type=str, default="facebook/sam2-hiera-small", help="HuggingFace 模型 ID")
|
parser.add_argument("--checkpoint", type=str, default="./sam2/checkpoints/sam2.1_hiera_small.pt", help="模型检查点")
|
||||||
parser.add_argument("--output_dir", type=str, default="./results/point_prompt_hf", help="输出目录")
|
parser.add_argument("--model_cfg", type=str, default="sam2.1_hiera_s.yaml", help="模型配置")
|
||||||
|
parser.add_argument("--output_dir", type=str, default="./results/point_prompt", help="输出目录")
|
||||||
parser.add_argument("--num_points", type=int, default=5, choices=[1, 3, 5], help="采样点数量")
|
parser.add_argument("--num_points", type=int, default=5, choices=[1, 3, 5], help="采样点数量")
|
||||||
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样")
|
parser.add_argument("--per_component", action="store_true", help="为每个连通域独立采样")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print("SAM2 点提示方式 (HuggingFace) - Crack500 数据集评估")
|
print("SAM2 点提示方式 - Crack500 数据集评估")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
print(f"数据集根目录: {args.data_root}")
|
print(f"数据集根目录: {args.data_root}")
|
||||||
print(f"测试集文件: {args.test_file}")
|
print(f"测试集文件: {args.test_file}")
|
||||||
print(f"模型: {args.model_id}")
|
print(f"模型检查点: {args.checkpoint}")
|
||||||
|
print(f"模型配置: {args.model_cfg}")
|
||||||
print(f"采样点数量: {args.num_points}")
|
print(f"采样点数量: {args.num_points}")
|
||||||
print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
|
print(f"采样策略: {'每连通域' if args.per_component else '全局骨架'}")
|
||||||
print(f"输出目录: {args.output_dir}")
|
print(f"输出目录: {args.output_dir}")
|
||||||
@ -307,10 +310,10 @@ def main():
|
|||||||
else:
|
else:
|
||||||
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
print(f"使用 GPU: {torch.cuda.get_device_name(0)}")
|
||||||
|
|
||||||
# 构建 SAM2 predictor
|
# 构建 SAM2 模型
|
||||||
print("\n加载 SAM2 模型...")
|
print("\n加载 SAM2 模型...")
|
||||||
from .hf_sam2_predictor import build_hf_sam2_predictor
|
sam2_model = build_sam2(args.model_cfg, args.checkpoint)
|
||||||
predictor = build_hf_sam2_predictor(model_id=args.model_id)
|
predictor = SAM2ImagePredictor(sam2_model)
|
||||||
print("模型加载完成!")
|
print("模型加载完成!")
|
||||||
|
|
||||||
# 处理测试集
|
# 处理测试集
|
||||||
|
|||||||
@ -1,8 +0,0 @@
|
|||||||
from .config import TaskConfig, TaskStepConfig
|
|
||||||
from .pipeline import TaskRunner
|
|
||||||
from .registry import TaskRegistry
|
|
||||||
|
|
||||||
# ensure built-in tasks are registered
|
|
||||||
from . import examples # noqa: F401
|
|
||||||
|
|
||||||
__all__ = ["TaskConfig", "TaskRunner", "TaskRegistry", "TaskStepConfig"]
|
|
||||||
@ -1,40 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, List, Literal, Optional
|
|
||||||
|
|
||||||
|
|
||||||
TaskStepKind = Literal[
|
|
||||||
"train",
|
|
||||||
"evaluate",
|
|
||||||
"visualize",
|
|
||||||
"bbox_inference",
|
|
||||||
"point_inference",
|
|
||||||
"legacy_evaluation",
|
|
||||||
"legacy_visualization",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TaskStepConfig:
|
|
||||||
kind: TaskStepKind
|
|
||||||
dataset_split: Optional[str] = None
|
|
||||||
dataset_split_file: Optional[str] = None
|
|
||||||
limit: Optional[int] = None
|
|
||||||
eval_split: Optional[str] = None
|
|
||||||
eval_split_file: Optional[str] = None
|
|
||||||
params: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TaskConfig:
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
project_config_name: str
|
|
||||||
model_key: str = "sam2"
|
|
||||||
steps: List[TaskStepConfig] = field(default_factory=list)
|
|
||||||
dataset_overrides: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
model_overrides: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
training_overrides: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
evaluation_overrides: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
visualization_overrides: Dict[str, Any] = field(default_factory=dict)
|
|
||||||
@ -1,34 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from .config import TaskConfig, TaskStepConfig
|
|
||||||
from .registry import TaskRegistry
|
|
||||||
|
|
||||||
TaskRegistry.register(
|
|
||||||
TaskConfig(
|
|
||||||
name="sam2_crack500_eval",
|
|
||||||
description="Evaluate SAM2 bbox prompt checkpoints on Crack500 and render overlays.",
|
|
||||||
project_config_name="sam2_bbox_prompt",
|
|
||||||
steps=[
|
|
||||||
TaskStepConfig(kind="evaluate", dataset_split="test"),
|
|
||||||
TaskStepConfig(kind="visualize", dataset_split="test", limit=20),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
TaskRegistry.register(
|
|
||||||
TaskConfig(
|
|
||||||
name="sam2_crack500_train_eval",
|
|
||||||
description="Fine-tune SAM2 on Crack500 train split, evaluate on val, then visualize results.",
|
|
||||||
project_config_name="sam2_bbox_prompt",
|
|
||||||
steps=[
|
|
||||||
TaskStepConfig(
|
|
||||||
kind="train",
|
|
||||||
dataset_split="train",
|
|
||||||
eval_split="val",
|
|
||||||
params={"num_train_epochs": 2},
|
|
||||||
),
|
|
||||||
TaskStepConfig(kind="evaluate", dataset_split="val", limit=32),
|
|
||||||
TaskStepConfig(kind="visualize", dataset_split="val", limit=16),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@ -1,40 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import tomllib
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, List
|
|
||||||
|
|
||||||
from .config import TaskConfig, TaskStepConfig
|
|
||||||
|
|
||||||
|
|
||||||
def load_task_from_toml(path: str | Path) -> TaskConfig:
|
|
||||||
"""
|
|
||||||
Load a TaskConfig from a TOML file.
|
|
||||||
"""
|
|
||||||
data = tomllib.loads(Path(path).read_text(encoding="utf-8"))
|
|
||||||
task_data = data.get("task", {})
|
|
||||||
steps_data: List[Dict[str, Any]] = data.get("steps", [])
|
|
||||||
steps = [
|
|
||||||
TaskStepConfig(
|
|
||||||
kind=step["kind"],
|
|
||||||
dataset_split=step.get("dataset_split"),
|
|
||||||
dataset_split_file=step.get("dataset_split_file"),
|
|
||||||
limit=step.get("limit"),
|
|
||||||
eval_split=step.get("eval_split"),
|
|
||||||
eval_split_file=step.get("eval_split_file"),
|
|
||||||
params=step.get("params", {}),
|
|
||||||
)
|
|
||||||
for step in steps_data
|
|
||||||
]
|
|
||||||
return TaskConfig(
|
|
||||||
name=task_data["name"],
|
|
||||||
description=task_data.get("description", ""),
|
|
||||||
project_config_name=task_data["project_config_name"],
|
|
||||||
model_key=task_data.get("model_key", "sam2"),
|
|
||||||
steps=steps,
|
|
||||||
dataset_overrides=task_data.get("dataset_overrides", {}),
|
|
||||||
model_overrides=task_data.get("model_overrides", {}),
|
|
||||||
training_overrides=task_data.get("training_overrides", {}),
|
|
||||||
evaluation_overrides=task_data.get("evaluation_overrides", {}),
|
|
||||||
visualization_overrides=task_data.get("visualization_overrides", {}),
|
|
||||||
)
|
|
||||||
@ -1,264 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import fields, replace
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
from ..bbox_prompt import process_test_set as bbox_process_test_set
|
|
||||||
from ..dataset import DatasetRegistry
|
|
||||||
from ..evaluation import PipelineEvaluator
|
|
||||||
from ..evaluation.utils import extract_mask_from_pipeline_output
|
|
||||||
from ..hf_sam2_predictor import build_hf_sam2_predictor
|
|
||||||
from ..legacy_evaluation import evaluate_test_set as legacy_evaluate_test_set
|
|
||||||
from ..legacy_visualization import (
|
|
||||||
create_metrics_distribution_plot,
|
|
||||||
visualize_test_set as legacy_visualize_test_set,
|
|
||||||
)
|
|
||||||
from ..model import FineTuningTrainer, ModelRegistry
|
|
||||||
from ..model_configuration import ConfigRegistry, DatasetConfig, ProjectConfig
|
|
||||||
from ..point_prompt import process_test_set as point_process_test_set
|
|
||||||
from ..visualization import OverlayGenerator
|
|
||||||
from .config import TaskConfig, TaskStepConfig
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _replace_dataclass(instance, updates: Dict[str, Any]):
|
|
||||||
if not updates:
|
|
||||||
return instance
|
|
||||||
valid_fields = {f.name for f in fields(type(instance))}
|
|
||||||
filtered = {k: v for k, v in updates.items() if k in valid_fields}
|
|
||||||
if not filtered:
|
|
||||||
return instance
|
|
||||||
return replace(instance, **filtered)
|
|
||||||
|
|
||||||
|
|
||||||
def _override_dataset(config: DatasetConfig, split: str, split_file: Optional[str]) -> DatasetConfig:
|
|
||||||
updates: Dict[str, Any] = {"split": split}
|
|
||||||
if split_file:
|
|
||||||
updates["split_file"] = split_file
|
|
||||||
return replace(config, **updates)
|
|
||||||
|
|
||||||
|
|
||||||
class TaskRunner:
|
|
||||||
"""
|
|
||||||
Sequentially executes a series of task steps (train/eval/visualize).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, task_config: TaskConfig, project_config: Optional[ProjectConfig] = None) -> None:
|
|
||||||
self.task_config = task_config
|
|
||||||
base_project = project_config or ConfigRegistry.get(task_config.project_config_name)
|
|
||||||
if project_config is None:
|
|
||||||
base_project = self._apply_project_overrides(base_project)
|
|
||||||
self.project_config = base_project
|
|
||||||
self.adapter = ModelRegistry.create(task_config.model_key, self.project_config.model)
|
|
||||||
|
|
||||||
def run(self) -> None:
|
|
||||||
LOGGER.info("Starting task '%s'", self.task_config.name)
|
|
||||||
for idx, step in enumerate(self.task_config.steps, start=1):
|
|
||||||
LOGGER.info("Running step %d/%d: %s", idx, len(self.task_config.steps), step.kind)
|
|
||||||
if step.kind == "train":
|
|
||||||
self._run_train(step)
|
|
||||||
elif step.kind == "evaluate":
|
|
||||||
self._run_evaluate(step)
|
|
||||||
elif step.kind == "visualize":
|
|
||||||
self._run_visualize(step)
|
|
||||||
elif step.kind == "bbox_inference":
|
|
||||||
self._run_bbox_inference(step)
|
|
||||||
elif step.kind == "point_inference":
|
|
||||||
self._run_point_inference(step)
|
|
||||||
elif step.kind == "legacy_evaluation":
|
|
||||||
self._run_legacy_evaluation(step)
|
|
||||||
elif step.kind == "legacy_visualization":
|
|
||||||
self._run_legacy_visualization(step)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown task step: {step.kind}")
|
|
||||||
|
|
||||||
def _build_dataset(self, split: str, split_file: Optional[str]):
|
|
||||||
dataset_cfg = _override_dataset(self.project_config.dataset, split, split_file)
|
|
||||||
return DatasetRegistry.create(
|
|
||||||
dataset_cfg.name,
|
|
||||||
config=dataset_cfg,
|
|
||||||
return_hf_dict=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _apply_project_overrides(self, config: ProjectConfig) -> ProjectConfig:
|
|
||||||
dataset_cfg = config.dataset
|
|
||||||
if self.task_config.dataset_overrides:
|
|
||||||
dataset_cfg = self._apply_dataset_overrides(dataset_cfg, self.task_config.dataset_overrides)
|
|
||||||
evaluation_cfg = config.evaluation
|
|
||||||
if self.task_config.evaluation_overrides:
|
|
||||||
evaluation_cfg = self._apply_simple_overrides(evaluation_cfg, self.task_config.evaluation_overrides)
|
|
||||||
visualization_cfg = config.visualization
|
|
||||||
if self.task_config.visualization_overrides:
|
|
||||||
visualization_cfg = self._apply_simple_overrides(
|
|
||||||
visualization_cfg, self.task_config.visualization_overrides
|
|
||||||
)
|
|
||||||
model_cfg = config.model
|
|
||||||
if self.task_config.model_overrides:
|
|
||||||
model_cfg = self._apply_simple_overrides(model_cfg, self.task_config.model_overrides)
|
|
||||||
training_cfg = config.training
|
|
||||||
if self.task_config.training_overrides:
|
|
||||||
training_cfg = self._apply_simple_overrides(training_cfg, self.task_config.training_overrides)
|
|
||||||
return replace(
|
|
||||||
config,
|
|
||||||
dataset=dataset_cfg,
|
|
||||||
model=model_cfg,
|
|
||||||
training=training_cfg,
|
|
||||||
evaluation=evaluation_cfg,
|
|
||||||
visualization=visualization_cfg,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _apply_dataset_overrides(self, dataset_cfg: DatasetConfig, overrides: Dict[str, Any]) -> DatasetConfig:
|
|
||||||
overrides = dict(overrides)
|
|
||||||
extra_updates = overrides.pop("extra_params", {})
|
|
||||||
merged_extra = dict(dataset_cfg.extra_params or {})
|
|
||||||
merged_extra.update(extra_updates)
|
|
||||||
return replace(dataset_cfg, **overrides, extra_params=merged_extra)
|
|
||||||
|
|
||||||
def _apply_simple_overrides(self, cfg, overrides: Dict[str, Any]):
|
|
||||||
overrides = dict(overrides)
|
|
||||||
return replace(cfg, **overrides)
|
|
||||||
|
|
||||||
def _default_data_root(self) -> str:
|
|
||||||
return self.project_config.dataset.data_root
|
|
||||||
|
|
||||||
def _default_test_file(self) -> str:
|
|
||||||
dataset_cfg = self.project_config.dataset
|
|
||||||
candidate = dataset_cfg.split_file or "test.txt"
|
|
||||||
candidate_path = Path(candidate)
|
|
||||||
if candidate_path.is_absolute():
|
|
||||||
return str(candidate_path)
|
|
||||||
return str(Path(dataset_cfg.data_root) / candidate)
|
|
||||||
|
|
||||||
def _default_output_dir(self) -> str:
|
|
||||||
return self.project_config.evaluation.output_dir
|
|
||||||
|
|
||||||
def _run_train(self, step: TaskStepConfig) -> None:
|
|
||||||
train_dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
|
|
||||||
eval_dataset = None
|
|
||||||
if step.eval_split:
|
|
||||||
eval_dataset = self._build_dataset(step.eval_split, step.eval_split_file)
|
|
||||||
trainer_builder = FineTuningTrainer(
|
|
||||||
adapter=self.adapter,
|
|
||||||
train_dataset=train_dataset,
|
|
||||||
eval_dataset=eval_dataset,
|
|
||||||
training_config=_replace_dataclass(
|
|
||||||
self.project_config.training,
|
|
||||||
dict(step.params),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
artifacts = trainer_builder.build()
|
|
||||||
train_result = artifacts.trainer.train()
|
|
||||||
LOGGER.info("Training result: %s", train_result)
|
|
||||||
artifacts.trainer.save_model(self.project_config.training.output_dir)
|
|
||||||
if eval_dataset:
|
|
||||||
metrics = artifacts.trainer.evaluate()
|
|
||||||
LOGGER.info("Evaluation metrics: %s", metrics)
|
|
||||||
|
|
||||||
def _run_evaluate(self, step: TaskStepConfig) -> None:
|
|
||||||
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
|
|
||||||
evaluation_cfg = _replace_dataclass(
|
|
||||||
self.project_config.evaluation,
|
|
||||||
{**dict(step.params), "max_samples": step.limit},
|
|
||||||
)
|
|
||||||
evaluator = PipelineEvaluator(
|
|
||||||
dataset=dataset,
|
|
||||||
adapter=self.adapter,
|
|
||||||
config=evaluation_cfg,
|
|
||||||
)
|
|
||||||
summary = evaluator.run()
|
|
||||||
LOGGER.info("Evaluation summary: %s", summary)
|
|
||||||
|
|
||||||
def _run_visualize(self, step: TaskStepConfig) -> None:
|
|
||||||
dataset = self._build_dataset(step.dataset_split, step.dataset_split_file)
|
|
||||||
vis_config = _replace_dataclass(
|
|
||||||
self.project_config.visualization,
|
|
||||||
{**dict(step.params), "num_samples": step.limit or self.project_config.visualization.num_samples},
|
|
||||||
)
|
|
||||||
overlay = OverlayGenerator(vis_config)
|
|
||||||
pipe = self.adapter.build_pipeline()
|
|
||||||
limit = min(vis_config.num_samples, len(dataset))
|
|
||||||
for idx in range(limit):
|
|
||||||
sample = dataset[idx]
|
|
||||||
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
|
|
||||||
pred_mask = extract_mask_from_pipeline_output(preds)
|
|
||||||
mask = sample.get("labels", {}).get("mask")
|
|
||||||
overlay.visualize_sample(
|
|
||||||
image=sample["pixel_values"],
|
|
||||||
prediction=pred_mask,
|
|
||||||
mask=mask,
|
|
||||||
metadata=sample.get("metadata"),
|
|
||||||
)
|
|
||||||
LOGGER.info("Saved overlays to %s", vis_config.save_dir)
|
|
||||||
|
|
||||||
def _run_bbox_inference(self, step: TaskStepConfig) -> None:
|
|
||||||
params = dict(step.params)
|
|
||||||
data_root = params.get("data_root", self._default_data_root())
|
|
||||||
test_file = params.get("test_file", self._default_test_file())
|
|
||||||
expand_ratio = params.get("expand_ratio", params.get("bbox_expand_ratio", 0.05))
|
|
||||||
output_dir = params.get("output_dir", self._default_output_dir())
|
|
||||||
model_id = params.get("model_id", self.project_config.model.name_or_path)
|
|
||||||
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
|
|
||||||
bbox_process_test_set(
|
|
||||||
data_root=data_root,
|
|
||||||
test_file=test_file,
|
|
||||||
predictor=predictor,
|
|
||||||
output_dir=output_dir,
|
|
||||||
expand_ratio=expand_ratio,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_point_inference(self, step: TaskStepConfig) -> None:
|
|
||||||
params = dict(step.params)
|
|
||||||
data_root = params.get("data_root", self._default_data_root())
|
|
||||||
test_file = params.get("test_file", self._default_test_file())
|
|
||||||
num_points = params.get("num_points", 5)
|
|
||||||
per_component = params.get("per_component", False)
|
|
||||||
output_dir = params.get("output_dir") or f"./results/point_prompt_{num_points}pts_hf"
|
|
||||||
model_id = params.get("model_id", self.project_config.model.name_or_path)
|
|
||||||
predictor = build_hf_sam2_predictor(model_id=model_id, device=params.get("device"))
|
|
||||||
point_process_test_set(
|
|
||||||
data_root=data_root,
|
|
||||||
test_file=test_file,
|
|
||||||
predictor=predictor,
|
|
||||||
output_dir=output_dir,
|
|
||||||
num_points=num_points,
|
|
||||||
per_component=per_component,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_legacy_evaluation(self, step: TaskStepConfig) -> None:
|
|
||||||
params = dict(step.params)
|
|
||||||
data_root = params.get("data_root", self._default_data_root())
|
|
||||||
test_file = params.get("test_file", self._default_test_file())
|
|
||||||
output_dir = params.get("output_dir", self._default_output_dir())
|
|
||||||
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
|
|
||||||
compute_skeleton = params.get("compute_skeleton", True)
|
|
||||||
legacy_evaluate_test_set(
|
|
||||||
data_root=data_root,
|
|
||||||
test_file=test_file,
|
|
||||||
pred_dir=pred_dir,
|
|
||||||
output_dir=output_dir,
|
|
||||||
compute_skeleton=compute_skeleton,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run_legacy_visualization(self, step: TaskStepConfig) -> None:
|
|
||||||
params = dict(step.params)
|
|
||||||
data_root = params.get("data_root", self._default_data_root())
|
|
||||||
test_file = params.get("test_file", self._default_test_file())
|
|
||||||
output_dir = params.get("output_dir", self._default_output_dir())
|
|
||||||
pred_dir = params.get("pred_dir", str(Path(output_dir) / "predictions"))
|
|
||||||
num_samples = params.get("num_samples", 20)
|
|
||||||
save_all = params.get("save_all", False)
|
|
||||||
results_csv = params.get("results_csv", str(Path(output_dir) / "evaluation_results.csv"))
|
|
||||||
legacy_visualize_test_set(
|
|
||||||
data_root=data_root,
|
|
||||||
test_file=test_file,
|
|
||||||
pred_dir=pred_dir,
|
|
||||||
output_dir=output_dir,
|
|
||||||
results_csv=results_csv if Path(results_csv).exists() else None,
|
|
||||||
num_samples=num_samples,
|
|
||||||
save_all=save_all,
|
|
||||||
)
|
|
||||||
if params.get("create_metrics_plot", True):
|
|
||||||
create_metrics_distribution_plot(results_csv, output_dir)
|
|
||||||
@ -1,28 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Dict
|
|
||||||
|
|
||||||
from .config import TaskConfig
|
|
||||||
|
|
||||||
|
|
||||||
class TaskRegistry:
|
|
||||||
"""
|
|
||||||
Holds named task configs for reuse.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_registry: Dict[str, TaskConfig] = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, task: TaskConfig) -> TaskConfig:
|
|
||||||
cls._registry[task.name] = task
|
|
||||||
return task
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get(cls, name: str) -> TaskConfig:
|
|
||||||
if name not in cls._registry:
|
|
||||||
raise KeyError(f"Task '{name}' is not registered.")
|
|
||||||
return cls._registry[name]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def available(cls) -> Dict[str, TaskConfig]:
|
|
||||||
return dict(cls._registry)
|
|
||||||
@ -1,44 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import HfArgumentParser
|
|
||||||
|
|
||||||
from .config import TaskConfig
|
|
||||||
from .io import load_task_from_toml
|
|
||||||
from .pipeline import TaskRunner
|
|
||||||
from .registry import TaskRegistry
|
|
||||||
|
|
||||||
# ensure built-in tasks are registered when CLI runs
|
|
||||||
from . import examples # noqa: F401
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TaskCLIArguments:
|
|
||||||
task_name: Optional[str] = None
|
|
||||||
task_file: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_task(cli_args: TaskCLIArguments) -> TaskConfig:
|
|
||||||
if not cli_args.task_name and not cli_args.task_file:
|
|
||||||
raise ValueError("Provide either --task_name or --task_file.")
|
|
||||||
if cli_args.task_file:
|
|
||||||
return load_task_from_toml(cli_args.task_file)
|
|
||||||
return TaskRegistry.get(cli_args.task_name)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = HfArgumentParser(TaskCLIArguments)
|
|
||||||
(cli_args,) = parser.parse_args_into_dataclasses()
|
|
||||||
task = resolve_task(cli_args)
|
|
||||||
runner = TaskRunner(task)
|
|
||||||
runner.run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
main()
|
|
||||||
@ -81,17 +81,17 @@ def create_comparison_figure(
|
|||||||
|
|
||||||
# 原始图像
|
# 原始图像
|
||||||
axes[0, 0].imshow(image)
|
axes[0, 0].imshow(image)
|
||||||
axes[0, 0].set_title("Original Image", fontsize=16)
|
axes[0, 0].set_title("Original Image", fontsize=12)
|
||||||
axes[0, 0].axis('off')
|
axes[0, 0].axis('off')
|
||||||
|
|
||||||
# GT 掩码
|
# GT 掩码
|
||||||
axes[0, 1].imshow(mask_gt, cmap='gray')
|
axes[0, 1].imshow(mask_gt, cmap='gray')
|
||||||
axes[0, 1].set_title("Ground Truth", fontsize=16)
|
axes[0, 1].set_title("Ground Truth", fontsize=12)
|
||||||
axes[0, 1].axis('off')
|
axes[0, 1].axis('off')
|
||||||
|
|
||||||
# 预测掩码
|
# 预测掩码
|
||||||
axes[1, 0].imshow(mask_pred, cmap='gray')
|
axes[1, 0].imshow(mask_pred, cmap='gray')
|
||||||
axes[1, 0].set_title("Prediction", fontsize=16)
|
axes[1, 0].set_title("Prediction", fontsize=12)
|
||||||
axes[1, 0].axis('off')
|
axes[1, 0].axis('off')
|
||||||
|
|
||||||
# 叠加可视化
|
# 叠加可视化
|
||||||
@ -110,16 +110,16 @@ def create_comparison_figure(
|
|||||||
axes[1, 1].text(
|
axes[1, 1].text(
|
||||||
0.02, 0.98, legend_text,
|
0.02, 0.98, legend_text,
|
||||||
transform=axes[1, 1].transAxes,
|
transform=axes[1, 1].transAxes,
|
||||||
fontsize=16,
|
fontsize=10,
|
||||||
verticalalignment='top',
|
verticalalignment='top',
|
||||||
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
|
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8)
|
||||||
)
|
)
|
||||||
axes[1, 1].set_title("Overlay Visualization", fontsize=16)
|
axes[1, 1].set_title("Overlay Visualization", fontsize=12)
|
||||||
axes[1, 1].axis('off')
|
axes[1, 1].axis('off')
|
||||||
|
|
||||||
# # 设置总标题
|
# 设置总标题
|
||||||
# if title:
|
if title:
|
||||||
# fig.suptitle(title, fontsize=16, fontweight='bold')
|
fig.suptitle(title, fontsize=14, fontweight='bold')
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
@ -1,4 +0,0 @@
|
|||||||
from .gallery import build_gallery
|
|
||||||
from .overlay import OverlayGenerator
|
|
||||||
|
|
||||||
__all__ = ["OverlayGenerator", "build_gallery"]
|
|
||||||
@ -1,28 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Iterable
|
|
||||||
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
def build_gallery(image_paths: Iterable[Path], output_path: Path, columns: int = 4) -> Path:
|
|
||||||
"""
|
|
||||||
Simple grid composer that stitches overlay PNGs into a gallery.
|
|
||||||
"""
|
|
||||||
image_paths = list(image_paths)
|
|
||||||
if not image_paths:
|
|
||||||
raise ValueError("No images provided for gallery.")
|
|
||||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
images = [Image.open(path).convert("RGB") for path in image_paths]
|
|
||||||
widths, heights = zip(*(img.size for img in images))
|
|
||||||
cell_w = max(widths)
|
|
||||||
cell_h = max(heights)
|
|
||||||
rows = (len(images) + columns - 1) // columns
|
|
||||||
canvas = Image.new("RGB", (cell_w * columns, cell_h * rows), color=(0, 0, 0))
|
|
||||||
for idx, img in enumerate(images):
|
|
||||||
row = idx // columns
|
|
||||||
col = idx % columns
|
|
||||||
canvas.paste(img, (col * cell_w, row * cell_h))
|
|
||||||
canvas.save(output_path)
|
|
||||||
return output_path
|
|
||||||
@ -1,62 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
from ..model_configuration import VisualizationConfig
|
|
||||||
|
|
||||||
|
|
||||||
class OverlayGenerator:
|
|
||||||
"""
|
|
||||||
Turns model predictions into side-by-side overlays for quick inspection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: VisualizationConfig) -> None:
|
|
||||||
self.config = config
|
|
||||||
Path(self.config.save_dir).mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
def visualize_sample(
|
|
||||||
self,
|
|
||||||
image: np.ndarray,
|
|
||||||
prediction: np.ndarray,
|
|
||||||
mask: Optional[np.ndarray],
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Path:
|
|
||||||
overlay = self._compose_overlay(image, prediction, mask)
|
|
||||||
filename = (
|
|
||||||
metadata.get("image_name", "sample")
|
|
||||||
if metadata
|
|
||||||
else "sample"
|
|
||||||
)
|
|
||||||
target = Path(self.config.save_dir) / f"{filename}_overlay.png"
|
|
||||||
Image.fromarray(overlay).save(target)
|
|
||||||
return target
|
|
||||||
|
|
||||||
def _compose_overlay(
|
|
||||||
self,
|
|
||||||
image: np.ndarray,
|
|
||||||
prediction: np.ndarray,
|
|
||||||
mask: Optional[np.ndarray],
|
|
||||||
) -> np.ndarray:
|
|
||||||
vis = image.copy()
|
|
||||||
pred_mask = self._normalize(prediction)
|
|
||||||
color = np.zeros_like(vis)
|
|
||||||
color[..., 0] = pred_mask
|
|
||||||
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
|
|
||||||
if mask is not None:
|
|
||||||
gt = self._normalize(mask)
|
|
||||||
color = np.zeros_like(vis)
|
|
||||||
color[..., 1] = gt
|
|
||||||
vis = (0.5 * vis + 0.5 * color).astype(np.uint8)
|
|
||||||
return vis
|
|
||||||
|
|
||||||
def _normalize(self, array: np.ndarray) -> np.ndarray:
|
|
||||||
normalized = array.astype(np.float32)
|
|
||||||
normalized -= normalized.min()
|
|
||||||
denom = normalized.max() or 1.0
|
|
||||||
normalized = normalized / denom
|
|
||||||
normalized = (normalized * 255).astype(np.uint8)
|
|
||||||
return normalized
|
|
||||||
@ -1,58 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, replace
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from transformers import HfArgumentParser
|
|
||||||
|
|
||||||
from ..dataset import DatasetRegistry
|
|
||||||
from ..evaluation.utils import extract_mask_from_pipeline_output
|
|
||||||
from ..model import ModelRegistry
|
|
||||||
from ..model_configuration import ConfigRegistry
|
|
||||||
from .overlay import OverlayGenerator
|
|
||||||
|
|
||||||
LOGGER = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class VisualizationCLIArguments:
|
|
||||||
config_name: str = "sam2_bbox_prompt"
|
|
||||||
model_key: str = "sam2"
|
|
||||||
split: str = "test"
|
|
||||||
split_file: Optional[str] = None
|
|
||||||
num_samples: int = 20
|
|
||||||
device: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = HfArgumentParser(VisualizationCLIArguments)
|
|
||||||
(cli_args,) = parser.parse_args_into_dataclasses()
|
|
||||||
project_config = ConfigRegistry.get(cli_args.config_name)
|
|
||||||
dataset_cfg = replace(project_config.dataset, split=cli_args.split, split_file=cli_args.split_file)
|
|
||||||
dataset = DatasetRegistry.create(
|
|
||||||
dataset_cfg.name,
|
|
||||||
config=dataset_cfg,
|
|
||||||
return_hf_dict=True,
|
|
||||||
)
|
|
||||||
adapter = ModelRegistry.create(cli_args.model_key, project_config.model)
|
|
||||||
overlay = OverlayGenerator(project_config.visualization)
|
|
||||||
pipe = adapter.build_pipeline(device=cli_args.device)
|
|
||||||
limit = min(cli_args.num_samples, len(dataset))
|
|
||||||
for idx in range(limit):
|
|
||||||
sample = dataset[idx]
|
|
||||||
preds = pipe(pixel_values=sample["pixel_values"], prompts=sample.get("prompts"))
|
|
||||||
pred_mask = extract_mask_from_pipeline_output(preds)
|
|
||||||
mask = sample.get("labels", {}).get("mask")
|
|
||||||
overlay.visualize_sample(
|
|
||||||
image=sample["pixel_values"],
|
|
||||||
prediction=pred_mask,
|
|
||||||
mask=mask,
|
|
||||||
metadata=sample.get("metadata"),
|
|
||||||
)
|
|
||||||
LOGGER.info("Saved overlays to %s", project_config.visualization.save_dir)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
main()
|
|
||||||
@ -1,34 +0,0 @@
|
|||||||
[task]
|
|
||||||
name = "bbox_cli_template"
|
|
||||||
description = "Run legacy bbox-prompt inference + evaluation + visualization"
|
|
||||||
project_config_name = "sam2_bbox_prompt"
|
|
||||||
|
|
||||||
[[steps]]
|
|
||||||
kind = "bbox_inference"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
model_id = "facebook/sam2-hiera-small"
|
|
||||||
output_dir = "./results/bbox_prompt"
|
|
||||||
expand_ratio = 0.05
|
|
||||||
|
|
||||||
[[steps]]
|
|
||||||
kind = "legacy_evaluation"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
output_dir = "./results/bbox_prompt"
|
|
||||||
pred_dir = "./results/bbox_prompt/predictions"
|
|
||||||
compute_skeleton = true
|
|
||||||
|
|
||||||
[[steps]]
|
|
||||||
kind = "legacy_visualization"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
output_dir = "./results/bbox_prompt"
|
|
||||||
pred_dir = "./results/bbox_prompt/predictions"
|
|
||||||
results_csv = "./results/bbox_prompt/evaluation_results.csv"
|
|
||||||
num_samples = 20
|
|
||||||
save_all = false
|
|
||||||
create_metrics_plot = true
|
|
||||||
@ -1,100 +0,0 @@
|
|||||||
[task]
|
|
||||||
name = "point_cli_template"
|
|
||||||
description = "Run legacy point-prompt inference/eval/visualization for multiple configs"
|
|
||||||
project_config_name = "sam2_bbox_prompt"
|
|
||||||
|
|
||||||
# 1 point config
|
|
||||||
[[steps]]
|
|
||||||
kind = "point_inference"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
model_id = "facebook/sam2-hiera-small"
|
|
||||||
num_points = 1
|
|
||||||
per_component = false
|
|
||||||
output_dir = "./results/point_prompt_1pts_hf"
|
|
||||||
|
|
||||||
[[steps]]
|
|
||||||
kind = "legacy_evaluation"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
output_dir = "./results/point_prompt_1pts_hf"
|
|
||||||
pred_dir = "./results/point_prompt_1pts_hf/predictions"
|
|
||||||
compute_skeleton = true
|
|
||||||
|
|
||||||
[[steps]]
|
|
||||||
kind = "legacy_visualization"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
output_dir = "./results/point_prompt_1pts_hf"
|
|
||||||
pred_dir = "./results/point_prompt_1pts_hf/predictions"
|
|
||||||
results_csv = "./results/point_prompt_1pts_hf/evaluation_results.csv"
|
|
||||||
num_samples = 10
|
|
||||||
save_all = false
|
|
||||||
create_metrics_plot = true
|
|
||||||
|
|
||||||
# 3 point config
|
|
||||||
[[steps]]
|
|
||||||
kind = "point_inference"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
model_id = "facebook/sam2-hiera-small"
|
|
||||||
num_points = 3
|
|
||||||
per_component = false
|
|
||||||
output_dir = "./results/point_prompt_3pts_hf"
|
|
||||||
|
|
||||||
[[steps]]
|
|
||||||
kind = "legacy_evaluation"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
output_dir = "./results/point_prompt_3pts_hf"
|
|
||||||
pred_dir = "./results/point_prompt_3pts_hf/predictions"
|
|
||||||
compute_skeleton = true
|
|
||||||
|
|
||||||
[[steps]]
|
|
||||||
kind = "legacy_visualization"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
output_dir = "./results/point_prompt_3pts_hf"
|
|
||||||
pred_dir = "./results/point_prompt_3pts_hf/predictions"
|
|
||||||
results_csv = "./results/point_prompt_3pts_hf/evaluation_results.csv"
|
|
||||||
num_samples = 10
|
|
||||||
save_all = false
|
|
||||||
create_metrics_plot = true
|
|
||||||
|
|
||||||
# 5 point config
|
|
||||||
[[steps]]
|
|
||||||
kind = "point_inference"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
model_id = "facebook/sam2-hiera-small"
|
|
||||||
num_points = 5
|
|
||||||
per_component = false
|
|
||||||
output_dir = "./results/point_prompt_5pts_hf"
|
|
||||||
|
|
||||||
[[steps]]
|
|
||||||
kind = "legacy_evaluation"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
output_dir = "./results/point_prompt_5pts_hf"
|
|
||||||
pred_dir = "./results/point_prompt_5pts_hf/predictions"
|
|
||||||
compute_skeleton = true
|
|
||||||
|
|
||||||
[[steps]]
|
|
||||||
kind = "legacy_visualization"
|
|
||||||
[steps.params]
|
|
||||||
data_root = "./crack500"
|
|
||||||
test_file = "./crack500/test.txt"
|
|
||||||
output_dir = "./results/point_prompt_5pts_hf"
|
|
||||||
pred_dir = "./results/point_prompt_5pts_hf/predictions"
|
|
||||||
results_csv = "./results/point_prompt_5pts_hf/evaluation_results.csv"
|
|
||||||
num_samples = 10
|
|
||||||
save_all = false
|
|
||||||
create_metrics_plot = true
|
|
||||||
Loading…
x
Reference in New Issue
Block a user