# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

SAM2 (Segment Anything Model 2) evaluation project for crack segmentation on the Crack500 dataset. The project evaluates SAM2's performance using different prompting strategies: bounding box prompts and point prompts (1/3/5 points).

## Environment Setup

This project uses Pixi for dependency management:

```bash
# Install dependencies
pixi install

# Activate environment
pixi shell
```

Key dependencies are defined in `pixi.toml`. SAM2 is installed as an editable package from `/home/dustella/projects/sam2`.

## Running Evaluations

### Bounding Box Prompt Evaluation

```bash
# Full pipeline (inference + evaluation + visualization)
python run_bbox_evaluation.py

# With custom parameters
python run_bbox_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --expand_ratio 0.05 \
    --output_dir ./results/bbox_prompt \
    --num_vis 20

# Skip specific steps
python run_bbox_evaluation.py --skip_inference      # Use existing predictions
python run_bbox_evaluation.py --skip_evaluation     # Skip metrics calculation
python run_bbox_evaluation.py --skip_visualization  # Skip visualization
```

### Point Prompt Evaluation

```bash
# Evaluate with 1, 3, and 5 points
python run_point_evaluation.py \
    --data_root ./crack500 \
    --test_file ./crack500/test.txt \
    --point_configs 1 3 5 \
    --per_component
```

## Architecture

### Core Modules

**src/bbox_prompt.py**: Bounding box prompting implementation
- `extract_bboxes_from_mask()`: Extracts bounding boxes from GT masks using connected component analysis
- `predict_with_bbox_prompt()`: Runs SAM2 inference with bbox prompts
- `process_test_set()`: Batch processes the test set

**src/point_prompt.py**: Point prompting implementation
- Similar structure to bbox_prompt.py, but uses point coordinates as prompts
- Supports multiple points per connected component

**src/evaluation.py**: Metrics computation
- `compute_iou()`: Intersection over Union
- `compute_dice()`: Dice coefficient
- `compute_precision_recall()`: Precision and recall
- `compute_skeleton_iou()`: Skeleton IoU for thin crack structures

**src/visualization.py**: Visualization utilities
- `create_overlay_visualization()`: Creates TP/FP/FN overlays (yellow/red/green)
- `create_comparison_figure()`: 4-panel comparison (original, GT, prediction, overlay)
- `create_metrics_distribution_plot()`: Plots metric distributions

### Data Format

Test set file (`crack500/test.txt`) format:

```
testcrop/image_name.jpg testcrop/mask_name.png
```

Images are in `crack500/testcrop/`, masks in `crack500/testdata/`.

### Prompting Strategies

**Bounding Box**: Extracts bboxes from GT masks via connected components, optionally expanding each box by `expand_ratio` (default 5%) to simulate imprecise annotations. See the first sketch below.

**Point Prompts**: Samples N points (1/3/5) from each connected component in the GT mask. When `--per_component` is used, points are sampled per component; otherwise, globally. See the second sketch below.
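For orientation, here is a minimal sketch of the bbox extraction step. It follows the facts listed under "Key Implementation Details" below (8-connectivity, 10-pixel minimum area, 0/255 masks), but the actual `extract_bboxes_from_mask()` in `src/bbox_prompt.py` may differ in signature and edge-case handling:

```python
import cv2
import numpy as np

def extract_bboxes_from_mask(mask, expand_ratio=0.05, min_area=10):
    """One XYXY bbox per connected component, expanded by expand_ratio."""
    binary = (mask > 127).astype(np.uint8)  # masks are binary 0/255
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
    h, w = binary.shape
    bboxes = []
    for i in range(1, num_labels):  # label 0 is the background
        x, y, bw, bh, area = stats[i]
        if area < min_area:  # drop noise components
            continue
        dx, dy = int(bw * expand_ratio), int(bh * expand_ratio)  # simulate imprecise boxes
        bboxes.append((max(0, x - dx), max(0, y - dy),
                       min(w, x + bw + dx), min(h, y + bh + dy)))
    return bboxes
```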
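Point prompting can be sketched similarly. Uniform random sampling is an assumption here; the selection rule in `src/point_prompt.py` may differ. `labels` and `num_labels` are outputs of the same connected-component pass:

```python
import numpy as np

def sample_points_per_component(labels, num_labels, n_points=3, seed=0):
    """Sample up to n_points foreground points per component (the --per_component path)."""
    rng = np.random.default_rng(seed)
    coords, point_labels = [], []
    for i in range(1, num_labels):  # skip background label 0
        ys, xs = np.nonzero(labels == i)
        if xs.size == 0:
            continue
        idx = rng.choice(xs.size, size=min(n_points, xs.size), replace=False)
        coords.extend(zip(xs[idx], ys[idx]))  # SAM2 expects (x, y) order
        point_labels.extend([1] * idx.size)   # 1 = positive (foreground) point
    return np.array(coords), np.array(point_labels)
```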
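End to end, bbox-prompted inference looks roughly like this, using the official SAM2 predictor API with the checkpoint and config from the "Model Configuration" section below. The image/mask file names are placeholders, and `extract_bboxes_from_mask()` is the sketch above:

```python
import cv2
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor(build_sam2(
    "configs/sam2.1/sam2.1_hiera_s.yaml",
    "../sam2/checkpoints/sam2.1_hiera_small.pt",
))

image = cv2.cvtColor(cv2.imread("crack500/testcrop/example.jpg"), cv2.COLOR_BGR2RGB)
gt_mask = cv2.imread("crack500/testdata/example.png", cv2.IMREAD_GRAYSCALE)
predictor.set_image(image)

# One prediction per bbox (multimask_output=False -> single mask per prompt);
# per-bbox masks are combined with logical OR, as in the pipeline
combined = np.zeros(image.shape[:2], dtype=bool)
for box in extract_bboxes_from_mask(gt_mask):
    masks, _, _ = predictor.predict(box=np.array(box), multimask_output=False)
    combined |= masks[0].astype(bool)
pred_mask = combined.astype(np.uint8) * 255  # binary 0/255, matching the repo's masks
```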
## Model Configuration

Default model: `sam2.1_hiera_small.pt`
- Located at: `../sam2/checkpoints/sam2.1_hiera_small.pt`
- Config: `configs/sam2.1/sam2.1_hiera_s.yaml`

Alternative models:
- `sam2.1_hiera_tiny.pt`: Fastest, lower accuracy
- `sam2.1_hiera_large.pt`: Highest accuracy, slower

## Results Structure

```
results/
└── bbox_prompt/
    ├── predictions/        # Predicted masks (.png)
    ├── visualizations/     # Comparison figures
    ├── evaluation_results.csv
    └── evaluation_summary.json
```

## Performance Benchmarks

Current SAM2 results on Crack500 (from note.md):
- Bbox prompt: 39.60% IoU, 53.58% F1
- 1 point: 8.43% IoU, 12.70% F1
- 3 points: 33.35% IoU, 45.94% F1
- 5 points: 38.50% IoU, 51.89% F1

Compared to specialized models (CrackSegMamba: 57.4% IoU, 72.9% F1), SAM2 underperforms without fine-tuning.

## Key Implementation Details

- Connected component analysis uses `cv2.connectedComponentsWithStats()` with 8-connectivity
- Minimum component area threshold: 10 pixels (filters noise)
- SAM2 inference uses `multimask_output=False` for a single mask per prompt
- Multiple bboxes/points per image are combined with logical OR
- Masks are binary (0 or 255)

## LoRA Fine-tuning

The project includes LoRA (Low-Rank Adaptation) fine-tuning to improve SAM2 performance on crack segmentation.

### Quick Start

```bash
# Verify dataset
python scripts/prepare_data.py

# Train with Strategy B (recommended)
python train_lora.py \
    --lora_strategy B \
    --epochs 50 \
    --batch_size 4 \
    --gradient_accumulation_steps 4 \
    --use_skeleton_loss \
    --use_wandb

# Evaluate fine-tuned model
python evaluate_lora.py \
    --lora_checkpoint ./results/lora_training/checkpoints/best_model.pt \
    --lora_strategy B
```

### LoRA Strategies

**Strategy A: Decoder-Only** (fast, ~2-3 hours)
- Targets only mask decoder attention layers
- ~2-3M trainable parameters
- Expected improvement: 45-50% IoU

```bash
python train_lora.py --lora_strategy A --epochs 30 --batch_size 8 --gradient_accumulation_steps 2
```

**Strategy B: Decoder + Encoder** (recommended, ~4-6 hours)
- Targets mask decoder + image encoder attention layers
- ~5-7M trainable parameters
- Expected improvement: 50-55% IoU

```bash
python train_lora.py --lora_strategy B --epochs 50 --batch_size 4 --gradient_accumulation_steps 4
```

**Strategy C: Full Model** (maximum, ~8-12 hours)
- Targets all attention + FFN layers
- ~10-15M trainable parameters
- Expected improvement: 55-60% IoU

```bash
python train_lora.py --lora_strategy C --epochs 80 --batch_size 2 --gradient_accumulation_steps 8 --learning_rate 3e-5
```

### Training Features

- **Gradient Accumulation**: Simulates larger batch sizes with limited GPU memory
- **Mixed Precision**: Automatic mixed precision (AMP) for memory efficiency
- **Early Stopping**: Stops training when validation IoU plateaus (patience=10)
- **Skeleton Loss**: Additional loss for thin crack structures
- **W&B Logging**: Track experiments with Weights & Biases (optional)

### Key Files

**Training Infrastructure:**
- `train_lora.py`: Main training script
- `evaluate_lora.py`: Evaluation script for fine-tuned models
- `src/dataset.py`: Dataset with automatic bbox prompt generation
- `src/losses.py`: Combined Dice + Focal loss, Skeleton loss
- `src/lora_utils.py`: Custom LoRA implementation (no PEFT dependency); see the sketch below

**Configuration:**
- `configs/lora_configs.py`: Three LoRA strategies (A, B, C)
- `configs/train_config.yaml`: Default hyperparameters
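As a rough picture of what `src/lora_utils.py` implements, a LoRA adapter wraps a frozen linear layer with a trainable low-rank update. The class below is an illustrative sketch, not the repo's actual code; the class name, rank, and alpha defaults are assumptions:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """y = W x + scale * B(A(x)), with W frozen and A, B trainable low-rank factors."""
    def __init__(self, base: nn.Linear, rank: int = 8, alpha: float = 16.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False           # freeze the pretrained weight
        self.lora_A = nn.Linear(base.in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, base.out_features, bias=False)
        nn.init.kaiming_uniform_(self.lora_A.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_B.weight)    # zero init: adapter is a no-op at step 0
        self.scale = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_B(self.lora_A(x)) * self.scale
```

The three strategies then differ only in which layers get wrapped this way: decoder attention (A), decoder + encoder attention (B), or all attention + FFN layers (C).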
### Training Process

1. **Data Preparation**: Automatically generates bbox prompts from GT masks during training
2. **Forward Pass**: Keeps all data on GPU (no CPU-GPU transfers during training)
3. **Loss Computation**: Combined Dice + Focal loss for class imbalance, optional Skeleton loss (see the loss sketch at the end of this file)
4. **Optimization**: AdamW optimizer with cosine annealing schedule
5. **Validation**: Computes IoU, Dice, and F1 on the validation set every epoch
6. **Checkpointing**: Saves the best model based on validation IoU

### Expected Results

**Baseline (no fine-tuning):**
- Bbox prompt: 39.60% IoU, 53.58% F1

**After fine-tuning:**
- Strategy A: 45-50% IoU (modest improvement)
- Strategy B: 50-55% IoU (significant improvement)
- Strategy C: 55-60% IoU (approaching SOTA)

**SOTA comparison:**
- CrackSegMamba: 57.4% IoU, 72.9% F1

### Important Notes

- Uses the official Meta SAM2 implementation (not the transformers library)
- Custom LoRA implementation to avoid dependency issues
- All training data stays on GPU for efficiency
- LoRA weights can be saved separately (~10-50MB vs ~200MB for the full model); see the save/load sketch below
- Compatible with existing evaluation infrastructure
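The combined loss from step 3 of the training process (`src/losses.py`) can be pictured as follows. This is one common Dice + Focal formulation; the repo's exact weights and focal parameters are not documented here, so treat these defaults as assumptions:

```python
import torch
import torch.nn.functional as F

def dice_loss(logits, target, eps=1.0):
    """Soft Dice loss on sigmoid probabilities; target is a {0,1} float mask."""
    prob = logits.sigmoid().flatten(1)
    target = target.flatten(1)
    inter = (prob * target).sum(-1)
    dice = (2 * inter + eps) / (prob.sum(-1) + target.sum(-1) + eps)
    return 1 - dice.mean()

def focal_loss(logits, target, alpha=0.25, gamma=2.0):
    """Binary focal loss: down-weights easy pixels to fight class imbalance."""
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    p_t = torch.exp(-bce)                                  # prob of the true class
    alpha_t = alpha * target + (1 - alpha) * (1 - target)  # class-balanced weight
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()

def combined_loss(logits, target, w_dice=1.0, w_focal=1.0):
    return w_dice * dice_loss(logits, target) + w_focal * focal_loss(logits, target)
```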
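For the note on saving LoRA weights separately, filtering the state dict down to the adapter tensors is the usual approach. The helper names and the `lora_` prefix are assumptions about the repo's naming, not its actual API:

```python
import torch

def save_lora_weights(model, path):
    """Persist only the adapter tensors (~10-50MB) instead of the full ~200MB model."""
    lora_state = {k: v for k, v in model.state_dict().items() if "lora_" in k}
    torch.save(lora_state, path)

def load_lora_weights(model, path):
    """Load adapter tensors into a model that already has LoRA layers attached."""
    model.load_state_dict(torch.load(path, map_location="cpu"), strict=False)
```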