sam_crack/evaluate_lora.py
2025-12-24 17:15:36 +08:00

155 lines
5.0 KiB
Python

#!/usr/bin/env python3
"""
Evaluate LoRA fine-tuned SAM2 model on Crack500 test set
Uses existing evaluation infrastructure from bbox_prompt.py and evaluation.py
"""
import torch
import argparse
from pathlib import Path
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from src.lora_utils import apply_lora_to_model, load_lora_weights
from src.bbox_prompt import process_test_set
from src.evaluation import evaluate_test_set
from src.visualization import visualize_test_set
from configs.lora_configs import get_strategy
def load_lora_model(args):
    """Build the base SAM2 model, wrap it with LoRA and load fine-tuned weights.

    Args:
        args: Namespace-like object exposing the attributes ``model_cfg``,
            ``base_checkpoint``, ``lora_strategy`` and ``lora_checkpoint``.

    Returns:
        The model returned by ``apply_lora_to_model`` with the checkpoint
        weights loaded, switched to evaluation mode via ``.eval()``.

    Raises:
        FileNotFoundError: If ``args.lora_checkpoint`` does not exist. The
            check runs before the base model is built so that a mistyped
            path fails immediately instead of after the model load.
    """
    # Fail fast: the checkpoint path is only consumed near the end of this
    # function, after the base model has already been built and wrapped.
    if not os.path.exists(args.lora_checkpoint):
        raise FileNotFoundError(
            f"LoRA checkpoint not found: {args.lora_checkpoint}"
        )

    print(f"Loading base SAM2 model: {args.model_cfg}")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load base model
    sam2_model = build_sam2(
        args.model_cfg,
        args.base_checkpoint,
        device=device,
        mode="eval",
    )

    # Apply LoRA configuration
    lora_config = get_strategy(args.lora_strategy)
    print(f"\nApplying LoRA strategy: {lora_config.name}")
    sam2_model = apply_lora_to_model(
        sam2_model,
        target_modules=lora_config.target_modules,
        r=lora_config.r,
        lora_alpha=lora_config.lora_alpha,
        lora_dropout=lora_config.lora_dropout,
    )

    # Load LoRA weights
    print(f"\nLoading LoRA weights from: {args.lora_checkpoint}")
    # NOTE(review): depending on the installed PyTorch version's
    # ``weights_only`` default, torch.load may unpickle arbitrary objects.
    # Only point this script at checkpoints from a trusted source.
    checkpoint = torch.load(args.lora_checkpoint, map_location=device)

    if 'model_state_dict' in checkpoint:
        # Training-style checkpoint: a state dict plus optional metadata.
        sam2_model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
        if 'metrics' in checkpoint:
            print(f"Checkpoint metrics: {checkpoint['metrics']}")
    else:
        # Load LoRA weights only. The helper is handed the path rather than
        # the already-loaded object (presumably it reads the file itself --
        # confirm against src/lora_utils.py).
        load_lora_weights(sam2_model, args.lora_checkpoint)

    sam2_model.eval()
    return sam2_model
def main():
    """Command-line entry point for evaluating a LoRA fine-tuned SAM2 model.

    Parses the command line, creates the output directory, and then runs up
    to three stages in order: inference, evaluation and visualization. Each
    stage is skipped when its ``--skip_*`` flag is given. Evaluation and
    visualization read predictions from ``<output_dir>/predictions``.
    """
    cli = argparse.ArgumentParser(description='Evaluate LoRA fine-tuned SAM2')

    # Model parameters
    cli.add_argument('--lora_checkpoint', type=str, required=True,
                     help='Path to LoRA checkpoint')
    cli.add_argument('--base_checkpoint', type=str,
                     default='../sam2/checkpoints/sam2.1_hiera_small.pt')
    cli.add_argument('--model_cfg', type=str,
                     default='configs/sam2.1/sam2.1_hiera_s.yaml')
    cli.add_argument('--lora_strategy', type=str, default='B',
                     choices=['A', 'B', 'C'])

    # Data parameters
    cli.add_argument('--data_root', type=str, default='./crack500')
    cli.add_argument('--test_file', type=str, default='./crack500/test.txt')
    cli.add_argument('--expand_ratio', type=float, default=0.05)

    # Output parameters
    cli.add_argument('--output_dir', type=str, default='./results/lora_eval')
    cli.add_argument('--num_vis', type=int, default=20,
                     help='Number of samples to visualize')
    cli.add_argument('--vis_all', action='store_true',
                     help='Visualize all samples')

    # Workflow control
    cli.add_argument('--skip_inference', action='store_true')
    cli.add_argument('--skip_evaluation', action='store_true')
    cli.add_argument('--skip_visualization', action='store_true')

    args = cli.parse_args()

    rule = "=" * 60

    def announce(title):
        # Three-line banner: blank line + rule, the title, closing rule.
        print("\n" + rule)
        print(title)
        print(rule)

    # Create output directory
    out_root = Path(args.output_dir)
    out_root.mkdir(parents=True, exist_ok=True)
    predictions_dir = str(out_root / 'predictions')

    # Stage 1: load the model and run inference
    if not args.skip_inference:
        announce("Loading LoRA model...")
        predictor = SAM2ImagePredictor(load_lora_model(args))

        announce("Running inference on test set...")
        process_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            predictor=predictor,
            output_dir=str(out_root),
            expand_ratio=args.expand_ratio,
        )

    # Stage 2: evaluate the saved predictions
    if not args.skip_evaluation:
        announce("Evaluating predictions...")
        evaluate_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            pred_dir=predictions_dir,
            output_dir=str(out_root),
        )

    # Stage 3: visualize the saved predictions
    if not args.skip_visualization:
        announce("Creating visualizations...")
        # --vis_all overrides --num_vis by passing None as the sample count.
        sample_count = None if args.vis_all else args.num_vis
        visualize_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            pred_dir=predictions_dir,
            output_dir=str(out_root),
            num_vis=sample_count,
        )

    print("\n" + rule)
    print("Evaluation complete!")
    print(f"Results saved to: {out_root}")
    print(rule)
# Run the evaluation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()