CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Project Overview

SAM2 (Segment Anything Model 2) evaluation project for crack segmentation on the Crack500 dataset. The project evaluates SAM2's performance using different prompting strategies: bounding box prompts and point prompts (1/3/5 points).

Environment Setup

This project uses Pixi for dependency management:

# Install dependencies
pixi install

# Activate environment
pixi shell

Key dependencies are defined in pixi.toml. SAM2 is installed as an editable package from /home/dustella/projects/sam2.

Running Evaluations

Bounding Box Prompt Evaluation

# Full pipeline (inference + evaluation + visualization)
python run_bbox_evaluation.py

# With custom parameters
python run_bbox_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --expand_ratio 0.05 \
    --output_dir ./results/bbox_prompt \
    --num_vis 20

# Skip specific steps
python run_bbox_evaluation.py --skip_inference      # Use existing predictions
python run_bbox_evaluation.py --skip_evaluation     # Skip metrics calculation
python run_bbox_evaluation.py --skip_visualization  # Skip visualization

Point Prompt Evaluation

# Evaluate with 1, 3, and 5 points
python run_point_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --point_configs 1 3 5 \
    --per_component

Architecture

Core Modules

src/bbox_prompt.py: Bounding box prompting implementation

  • extract_bboxes_from_mask(): Extracts bounding boxes from GT masks using connected component analysis
  • predict_with_bbox_prompt(): Runs SAM2 inference with bbox prompts
  • process_test_set(): Batch processes the test set
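
The extraction step is small enough to sketch. The following is a simplified, hypothetical version of what extract_bboxes_from_mask() does (names and exact clamping are illustrative, not the module's code), matching the connected-component and expansion behavior described in this file:

import cv2
import numpy as np

def extract_bboxes_sketch(mask, expand_ratio=0.05, min_area=10):
    """Illustrative per-component bbox extraction from a binary GT mask."""
    binary = (mask > 0).astype(np.uint8)
    # 8-connectivity, as noted under Key Implementation Details
    num, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    h, w = binary.shape
    bboxes = []
    for i in range(1, num):  # label 0 is the background
        x, y, bw, bh, area = stats[i]
        if area < min_area:  # drop tiny noise components
            continue
        dx, dy = int(bw * expand_ratio), int(bh * expand_ratio)
        # Expand and clamp to image bounds to simulate imprecise annotations
        bboxes.append([max(0, x - dx), max(0, y - dy),
                       min(w - 1, x + bw + dx), min(h - 1, y + bh + dy)])
    return np.array(bboxes)  # XYXY boxes, the layout SAM2's predictor accepts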

src/point_prompt.py: Point prompting implementation

  • Similar structure to bbox_prompt.py but uses point coordinates as prompts
  • Supports multiple points per connected component

src/evaluation.py: Metrics computation

  • compute_iou(): Intersection over Union
  • compute_dice(): Dice coefficient
  • compute_precision_recall(): Precision and Recall
  • compute_skeleton_iou(): Skeleton IoU for thin crack structures
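
For reference, the overlap metrics reduce to a few lines. This is a hedged sketch; in particular, the skeletonization backend is an assumption and src/evaluation.py may define skeleton IoU differently:

import numpy as np
from skimage.morphology import skeletonize  # assumed backend for skeleton IoU

def compute_iou(pred, gt):
    p, g = pred > 0, gt > 0  # treat any nonzero pixel as foreground
    union = np.logical_or(p, g).sum()
    return np.logical_and(p, g).sum() / union if union > 0 else 0.0

def compute_dice(pred, gt):
    p, g = pred > 0, gt > 0
    total = p.sum() + g.sum()
    return 2 * np.logical_and(p, g).sum() / total if total > 0 else 0.0

def compute_skeleton_iou(pred, gt):
    # IoU on one-pixel-wide centerlines; rewards getting thin cracks right
    return compute_iou(skeletonize(pred > 0), skeletonize(gt > 0))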

src/visualization.py: Visualization utilities

  • create_overlay_visualization(): Creates TP/FP/FN overlays (yellow/red/green)
  • create_comparison_figure(): 4-panel comparison (original, GT, prediction, overlay)
  • create_metrics_distribution_plot(): Plots metric distributions

Data Format

Test set file (crack500/test.txt) format:

testcrop/image_name.jpg testcrop/mask_name.png

Images are in crack500/testcrop/, masks in crack500/testdata/.
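
A minimal parsing helper (hypothetical; the evaluation scripts read this file themselves) shows how each pair resolves against --data_root:

from pathlib import Path

def load_test_pairs(data_root, test_file):
    """Yield (image_path, mask_path) pairs from the space-separated list."""
    root = Path(data_root)
    for line in Path(test_file).read_text().splitlines():
        if line.strip():
            img_rel, mask_rel = line.split()
            yield root / img_rel, root / mask_rel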

Prompting Strategies

Bounding Box: Extracts bboxes from GT masks via connected components, optionally expands by expand_ratio (default 5%) to simulate imprecise annotations.

Point Prompts: Samples N positive points (N = 1/3/5) from the GT mask as prompts. With --per_component, N points are sampled from each connected component; otherwise, N points are sampled globally across the whole mask (see the sketch below).
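
A hypothetical sketch of this sampling (function and parameter names are illustrative; src/point_prompt.py may use a different sampling rule):

import cv2
import numpy as np

def sample_point_prompts(mask, n_points, per_component=True, min_area=10, seed=0):
    """Sample positive point prompts (x, y) from a binary GT mask."""
    rng = np.random.default_rng(seed)
    binary = (mask > 0).astype(np.uint8)

    def pick(xs, ys):
        idx = rng.choice(len(xs), size=min(n_points, len(xs)), replace=False)
        return list(zip(xs[idx], ys[idx]))

    points = []
    if per_component:
        num, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
        for i in range(1, num):  # skip background label 0
            if stats[i, cv2.CC_STAT_AREA] >= min_area:
                ys, xs = np.nonzero(labels == i)
                points += pick(xs, ys)
    else:
        ys, xs = np.nonzero(binary)
        points += pick(xs, ys)
    coords = np.array(points, dtype=np.float32)     # (x, y) per point
    labels_ = np.ones(len(points), dtype=np.int32)  # 1 = foreground prompt
    return coords, labels_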

Model Configuration

Default model: sam2.1_hiera_small.pt

  • Located at: ../sam2/checkpoints/sam2.1_hiera_small.pt
  • Config: configs/sam2.1/sam2.1_hiera_s.yaml

Alternative models:

  • sam2.1_hiera_tiny.pt: Fastest, lower accuracy
  • sam2.1_hiera_large.pt: Highest accuracy, slower
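
Loading follows the official SAM2 API; a minimal sketch using the default paths above:

import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

checkpoint = "../sam2/checkpoints/sam2.1_hiera_small.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"
device = "cuda" if torch.cuda.is_available() else "cpu"

predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device=device))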

Results Structure

results/
└── bbox_prompt/
    ├── predictions/         # Predicted masks (.png)
    ├── visualizations/      # Comparison figures
    ├── evaluation_results.csv
    └── evaluation_summary.json

Performance Benchmarks

Current SAM2 results on Crack500 (from note.md):

  • Bbox prompt: 39.60% IoU, 53.58% F1
  • 1 point: 8.43% IoU, 12.70% F1
  • 3 points: 33.35% IoU, 45.94% F1
  • 5 points: 38.50% IoU, 51.89% F1

Compared to specialized models (CrackSegMamba: 57.4% IoU, 72.9% F1), SAM2 underperforms without fine-tuning.

Key Implementation Details

  • Connected component analysis uses cv2.connectedComponentsWithStats() with 8-connectivity
  • Minimum component area threshold: 10 pixels (filters noise)
  • SAM2 inference uses multimask_output=False for single mask per prompt
  • Multiple bboxes/points per image are combined with logical OR
  • Masks are binary (0 or 255)
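
Put together, per-image inference looks roughly like this (continuing from the predictor sketch under Model Configuration; image and bboxes are assumed inputs, and the real scripts add batching and I/O):

import numpy as np

predictor.set_image(image)                       # image: HxWx3 RGB uint8
combined = np.zeros(image.shape[:2], dtype=bool)
for box in bboxes:                               # XYXY boxes from the GT mask
    masks, scores, _ = predictor.predict(box=box, multimask_output=False)
    combined |= masks[0].astype(bool)            # logical-OR merge across prompts
pred_mask = combined.astype(np.uint8) * 255      # binary 0/255, as stored in predictions/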

LoRA Fine-tuning

The project includes LoRA (Low-Rank Adaptation) fine-tuning to improve SAM2 performance on crack segmentation.

Quick Start

# Verify dataset
python scripts/prepare_data.py

# Train with Strategy B (recommended)
python train_lora.py \
    --lora_strategy B \
    --epochs 50 \
    --batch_size 4 \
    --gradient_accumulation_steps 4 \
    --use_skeleton_loss \
    --use_wandb

# Evaluate fine-tuned model
python evaluate_lora.py \
    --lora_checkpoint ./results/lora_training/checkpoints/best_model.pt \
    --lora_strategy B

LoRA Strategies

Strategy A: Decoder-Only (Fast, ~2-3 hours)

  • Targets only mask decoder attention layers
  • ~2-3M trainable parameters
  • Expected result: 45-50% IoU
python train_lora.py --lora_strategy A --epochs 30 --batch_size 8 --gradient_accumulation_steps 2

Strategy B: Decoder + Encoder (Recommended, ~4-6 hours)

  • Targets mask decoder + image encoder attention layers
  • ~5-7M trainable parameters
  • Expected result: 50-55% IoU
python train_lora.py --lora_strategy B --epochs 50 --batch_size 4 --gradient_accumulation_steps 4

Strategy C: Full Model (Maximum, ~8-12 hours)

  • Targets all attention + FFN layers
  • ~10-15M trainable parameters
  • Expected result: 55-60% IoU
python train_lora.py --lora_strategy C --epochs 80 --batch_size 2 --gradient_accumulation_steps 8 --learning_rate 3e-5

Training Features

  • Gradient Accumulation: Simulates larger batch sizes with limited GPU memory
  • Mixed Precision: Automatic mixed precision (AMP) for memory efficiency
  • Early Stopping: Stops training when validation IoU plateaus (patience=10)
  • Skeleton Loss: Additional loss for thin crack structures
  • W&B Logging: Track experiments with Weights & Biases (optional)
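
The accumulation and AMP mechanics follow the standard PyTorch pattern. A hedged skeleton of one epoch (model, criterion, optimizer, and train_loader are assumed to exist; train_lora.py's actual loop is more involved):

import torch

scaler = torch.cuda.amp.GradScaler()
accum_steps = 4  # matches --gradient_accumulation_steps 4

optimizer.zero_grad()
for step, (images, boxes, gt_masks) in enumerate(train_loader):
    with torch.cuda.amp.autocast():
        logits = model(images, boxes)             # forward on GPU-resident batch
        loss = criterion(logits, gt_masks) / accum_steps
    scaler.scale(loss).backward()                 # accumulate scaled gradients
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)                    # unscale, then optimizer step
        scaler.update()
        optimizer.zero_grad()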

Key Files

Training Infrastructure:

  • train_lora.py: Main training script
  • evaluate_lora.py: Evaluation script for fine-tuned models
  • src/dataset.py: Dataset with automatic bbox prompt generation
  • src/losses.py: Combined Dice + Focal loss, Skeleton loss
  • src/lora_utils.py: Custom LoRA implementation (no PEFT dependency)
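
The core of a PEFT-free LoRA layer fits in a short class. This is a minimal sketch of the idea, not lora_utils.py's exact code:

import math
import torch.nn as nn

class LoRALinear(nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update:
    y = Wx + (alpha / r) * B(A(x))."""
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)               # freeze pretrained weights
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_b.weight)        # update starts at zero
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(x))

The three strategies then differ only in which attention (and, for Strategy C, FFN) layers get wrapped.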

Configuration:

  • configs/lora_configs.py: Three LoRA strategies (A, B, C)
  • configs/train_config.yaml: Default hyperparameters

Training Process

  1. Data Preparation: Automatically generates bbox prompts from GT masks during training
  2. Forward Pass: Keeps all data on GPU (no CPU-GPU transfers during training)
  3. Loss Computation: Combined Dice + Focal loss for class imbalance, optional Skeleton loss
  4. Optimization: AdamW optimizer with cosine annealing schedule
  5. Validation: Computes IoU, Dice, F1 on validation set every epoch
  6. Checkpointing: Saves best model based on validation IoU
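
Step 3 is the part most specific to cracks; a hedged sketch of a combined Dice + Focal term (weights and reductions here are illustrative, not src/losses.py's exact values):

import torch
import torch.nn.functional as F

def dice_focal_loss(logits, target, gamma=2.0, focal_weight=0.5, eps=1e-6):
    """logits: (B, 1, H, W) raw scores; target: float mask in {0, 1}."""
    prob = torch.sigmoid(logits)
    # Dice handles the foreground/background imbalance of thin cracks
    inter = (prob * target).sum(dim=(1, 2, 3))
    denom = prob.sum(dim=(1, 2, 3)) + target.sum(dim=(1, 2, 3))
    dice = 1 - (2 * inter + eps) / (denom + eps)
    # Focal down-weights easy pixels via the (1 - p_t)^gamma factor
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = prob * target + (1 - prob) * (1 - target)
    focal = ((1 - p_t) ** gamma * bce).mean(dim=(1, 2, 3))
    return (dice + focal_weight * focal).mean()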

Expected Results

Baseline (No Fine-tuning):

  • Bbox prompt: 39.60% IoU, 53.58% F1

After Fine-tuning:

  • Strategy A: 45-50% IoU (modest improvement)
  • Strategy B: 50-55% IoU (significant improvement)
  • Strategy C: 55-60% IoU (approaching SOTA)

SOTA Comparison:

  • CrackSegMamba: 57.4% IoU, 72.9% F1

Important Notes

  • Uses official Meta SAM2 implementation (not transformers library)
  • Custom LoRA implementation to avoid dependency issues
  • All training data stays on GPU for efficiency
  • LoRA weights can be saved separately (~10-50MB vs full model ~200MB); see the sketch after this list
  • Compatible with existing evaluation infrastructure
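
Saving only the adapter weights can be sketched as a filter over the state dict (assuming LoRA parameters are named with a lora_ prefix, as in the LoRALinear sketch above; model is the fine-tuned SAM2):

import torch

lora_state = {k: v for k, v in model.state_dict().items() if "lora_" in k}
torch.save(lora_state, "lora_only.pt")  # tens of MB instead of the full ~200MB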