#!/usr/bin/env python3
"""
Evaluate LoRA fine-tuned SAM2 model on Crack500 test set

Uses existing evaluation infrastructure from bbox_prompt.py and evaluation.py
"""

import torch
import argparse
from pathlib import Path
import sys
import os

# Make the project's src/ package importable when this file is run directly
# (must happen before the sam2/src imports below).
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from src.lora_utils import apply_lora_to_model, load_lora_weights
from src.bbox_prompt import process_test_set
from src.evaluation import evaluate_test_set
from src.visualization import visualize_test_set
from configs.lora_configs import get_strategy
def load_lora_model(args):
    """Build the base SAM2 model, attach LoRA adapters, and restore trained weights.

    Args:
        args: Parsed CLI namespace providing ``model_cfg``, ``base_checkpoint``,
            ``lora_strategy``, and ``lora_checkpoint``.

    Returns:
        The SAM2 model in eval mode with LoRA weights loaded.
    """
    print(f"Loading base SAM2 model: {args.model_cfg}")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Frozen base model in inference mode.
    model = build_sam2(args.model_cfg, args.base_checkpoint, device=device, mode="eval")

    # Inject LoRA adapters first so the checkpoint's extra parameters have a home.
    strategy = get_strategy(args.lora_strategy)
    print(f"\nApplying LoRA strategy: {strategy.name}")
    model = apply_lora_to_model(
        model,
        target_modules=strategy.target_modules,
        r=strategy.r,
        lora_alpha=strategy.lora_alpha,
        lora_dropout=strategy.lora_dropout,
    )

    print(f"\nLoading LoRA weights from: {args.lora_checkpoint}")
    ckpt = torch.load(args.lora_checkpoint, map_location=device)

    if 'model_state_dict' in ckpt:
        # Full training checkpoint: restore the entire model state dict.
        model.load_state_dict(ckpt['model_state_dict'])
        print(f"Loaded checkpoint from epoch {ckpt.get('epoch', 'unknown')}")
        if 'metrics' in ckpt:
            print(f"Checkpoint metrics: {ckpt['metrics']}")
    else:
        # Bare file holding only the LoRA adapter weights.
        load_lora_weights(model, args.lora_checkpoint)

    model.eval()
    return model
def _build_arg_parser():
    """Construct the CLI parser for the LoRA evaluation script.

    Returns:
        argparse.ArgumentParser configured with model, data, output and
        workflow-control arguments.
    """
    parser = argparse.ArgumentParser(description='Evaluate LoRA fine-tuned SAM2')

    # Model parameters
    parser.add_argument('--lora_checkpoint', type=str, required=True,
                        help='Path to LoRA checkpoint')
    parser.add_argument('--base_checkpoint', type=str,
                        default='../sam2/checkpoints/sam2.1_hiera_small.pt')
    parser.add_argument('--model_cfg', type=str, default='configs/sam2.1/sam2.1_hiera_s.yaml')
    parser.add_argument('--lora_strategy', type=str, default='B', choices=['A', 'B', 'C'])

    # Data parameters
    parser.add_argument('--data_root', type=str, default='./crack500')
    parser.add_argument('--test_file', type=str, default='./crack500/test.txt')
    parser.add_argument('--expand_ratio', type=float, default=0.05)

    # Output parameters
    parser.add_argument('--output_dir', type=str, default='./results/lora_eval')
    parser.add_argument('--num_vis', type=int, default=20,
                        help='Number of samples to visualize')
    parser.add_argument('--vis_all', action='store_true',
                        help='Visualize all samples')

    # Workflow control (each stage can be skipped independently)
    parser.add_argument('--skip_inference', action='store_true')
    parser.add_argument('--skip_evaluation', action='store_true')
    parser.add_argument('--skip_visualization', action='store_true')

    return parser


def _banner(*lines):
    """Print the given lines framed by '=' rules, matching the script's log style."""
    print("\n" + "=" * 60)
    for line in lines:
        print(line)
    print("=" * 60)


def main():
    """Run the LoRA evaluation workflow: inference, metrics, then visualization.

    Stages skipped via --skip_* flags are assumed to have run previously;
    later stages reuse predictions already present in <output_dir>/predictions.
    """
    args = _build_arg_parser().parse_args()

    # Create output directory; predictions land in a fixed subdirectory
    # shared by all three stages.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    pred_dir = output_dir / 'predictions'

    # Stage 1: run the LoRA model over the test set and save predictions.
    if not args.skip_inference:
        _banner("Loading LoRA model...")
        model = load_lora_model(args)
        predictor = SAM2ImagePredictor(model)

        _banner("Running inference on test set...")
        process_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            predictor=predictor,
            output_dir=str(output_dir),
            expand_ratio=args.expand_ratio
        )

    # Stage 2: score the saved predictions against ground truth.
    if not args.skip_evaluation:
        _banner("Evaluating predictions...")
        evaluate_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            pred_dir=str(pred_dir),
            output_dir=str(output_dir)
        )

    # Stage 3: render visualizations (all samples when --vis_all is set).
    if not args.skip_visualization:
        _banner("Creating visualizations...")
        visualize_test_set(
            data_root=args.data_root,
            test_file=args.test_file,
            pred_dir=str(pred_dir),
            output_dir=str(output_dir),
            num_vis=None if args.vis_all else args.num_vis
        )

    _banner("Evaluation complete!", f"Results saved to: {output_dir}")
# Script entry point: parse CLI args and run the evaluation workflow.
if __name__ == '__main__':
    main()