#!/usr/bin/env python3
"""
Evaluate LoRA fine-tuned SAM2 model on Crack500 test set
Uses existing evaluation infrastructure from bbox_prompt.py and evaluation.py

Pipeline (each stage can be skipped via CLI flags):
  1. inference      -> load base SAM2 + LoRA weights, run box-prompted prediction
  2. evaluation     -> score predictions in <output_dir>/predictions
  3. visualization  -> render a subset (or all) of the predictions
"""
import argparse
import os
import sys
from pathlib import Path

import torch

# Put ./src (next to this script) first on the import path before the project
# imports below. Presumably this lets modules inside src/ import each other by
# bare name — confirm before removing.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.lora_utils import apply_lora_to_model, load_lora_weights
from src.bbox_prompt import process_test_set
from src.evaluation import evaluate_test_set
from src.visualization import visualize_test_set
from configs.lora_configs import get_strategy


def load_lora_model(args):
    """Load SAM2 model with LoRA weights.

    Builds the base SAM2 model from ``args.model_cfg`` / ``args.base_checkpoint``,
    wraps it with LoRA adapters according to the strategy named by
    ``args.lora_strategy``, then restores weights from ``args.lora_checkpoint``.

    Two checkpoint layouts are accepted:
      * a dict containing ``'model_state_dict'`` (optionally ``'epoch'`` and
        ``'metrics'``), loaded with ``load_state_dict`` (strict);
      * anything else, which is delegated to ``load_lora_weights``.

    Args:
        args: Parsed CLI namespace from ``main``.

    Returns:
        The LoRA-wrapped SAM2 model in eval mode, built with device CUDA if
        available, otherwise CPU.
    """
    print(f"Loading base SAM2 model: {args.model_cfg}")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load base model
    sam2_model = build_sam2(
        args.model_cfg,
        args.base_checkpoint,
        device=device,
        mode="eval",
    )

    # Apply LoRA configuration. Done before loading the checkpoint, presumably
    # so the saved LoRA parameter keys exist in the model's state dict.
    lora_config = get_strategy(args.lora_strategy)
    print(f"\nApplying LoRA strategy: {lora_config.name}")
    sam2_model = apply_lora_to_model(
        sam2_model,
        target_modules=lora_config.target_modules,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
        lora_dropout=lora_config.lora_dropout,
    )

    # Load LoRA weights
    print(f"\nLoading LoRA weights from: {args.lora_checkpoint}")
    checkpoint = torch.load(args.lora_checkpoint, map_location=device)

    if 'model_state_dict' in checkpoint:
        # Full training checkpoint: restore the complete (LoRA-wrapped) state.
        sam2_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
        if 'metrics' in checkpoint:
            print(f"Checkpoint metrics: {checkpoint['metrics']}")
    else:
        # Load LoRA weights only.
        # NOTE: load_lora_weights takes the path and reads the file again; the
        # `checkpoint` object loaded above is not reused here.
        load_lora_weights(sam2_model, args.lora_checkpoint)

    sam2_model.eval()
    return sam2_model


def main():
    """CLI entry point: run inference, evaluation and visualization stages."""
    parser = argparse.ArgumentParser(description='Evaluate LoRA fine-tuned SAM2')

    # Model parameters
    parser.add_argument('--lora_checkpoint', type=str, required=True,
                        help='Path to LoRA checkpoint')
    parser.add_argument('--base_checkpoint', type=str,
                        default='../sam2/checkpoints/sam2.1_hiera_small.pt')
    parser.add_argument('--model_cfg', type=str,
                        default='configs/sam2.1/sam2.1_hiera_s.yaml')
    parser.add_argument('--lora_strategy', type=str, default='B',
                        choices=['A', 'B', 'C'])

    # Data parameters
    parser.add_argument('--data_root', type=str, default='./crack500')
    parser.add_argument('--test_file', type=str, default='./crack500/test.txt')
    parser.add_argument('--expand_ratio', type=float, default=0.05)

    # Output parameters
    parser.add_argument('--output_dir', type=str, default='./results/lora_eval')
    parser.add_argument('--num_vis', type=int, default=20,
                        help='Number of samples to visualize')
    parser.add_argument('--vis_all', action='store_true',
                        help='Visualize all samples')

    # Workflow control
    parser.add_argument('--skip_inference', action='store_true')
    parser.add_argument('--skip_evaluation', action='store_true')
    parser.add_argument('--skip_visualization', action='store_true')

    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    # Evaluation and visualization both read predictions from this directory.
    pred_dir = output_dir / 'predictions'

    # With --skip_inference, later stages must reuse predictions from an earlier
    # run. Report a clear CLI error up front rather than passing a nonexistent
    # directory to the evaluation/visualization helpers.
    needs_predictions = not (args.skip_evaluation and args.skip_visualization)
    if args.skip_inference and needs_predictions and not pred_dir.is_dir():
        parser.error(
            f"--skip_inference was given but no predictions directory exists at "
            f"{pred_dir}; run inference first or point --output_dir at a "
            f"previous run"
        )

    # Create output directory
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    if not args.skip_inference:
        print("\n" + "=" * 60)
        print("Loading LoRA model...")
        print("=" * 60)
        model = load_lora_model(args)
        predictor = SAM2ImagePredictor(model)

        # Run inference
        print("\n" + "=" * 60)
        print("Running inference on test set...")
        print("=" * 60)
        process_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            predictor=predictor,
            output_dir=str(output_dir),
            expand_ratio=args.expand_ratio,
        )

    # Evaluate
    if not args.skip_evaluation:
        print("\n" + "=" * 60)
        print("Evaluating predictions...")
        print("=" * 60)
        evaluate_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            pred_dir=str(pred_dir),
            output_dir=str(output_dir),
        )

    # Visualize
    if not args.skip_visualization:
        print("\n" + "=" * 60)
        print("Creating visualizations...")
        print("=" * 60)
        visualize_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            pred_dir=str(pred_dir),
            output_dir=str(output_dir),
            # None asks visualize_test_set for every sample (--vis_all).
            num_vis=None if args.vis_all else args.num_vis,
        )

    print("\n" + "=" * 60)
    print("Evaluation complete!")
    print(f"Results saved to: {output_dir}")
    print("=" * 60)


if __name__ == '__main__':
    main()