sam_crack/configs/lora_configs.py
2025-12-24 17:15:36 +08:00

91 lines
2.9 KiB
Python

"""
LoRA configuration strategies for SAM2 fine-tuning
Defines three strategies: Decoder-only, Decoder+Encoder, Full Model
"""
from dataclasses import dataclass
from typing import List
@dataclass
class LoRAConfig:
    """LoRA configuration for a specific strategy.

    Instances bundle the module-selection patterns and LoRA
    hyper-parameters for one fine-tuning strategy (see the
    STRATEGY_A/B/C constants in this module).
    """
    # Short machine-readable strategy identifier (e.g. "decoder_only").
    name: str
    # Human-readable summary of cost/benefit trade-offs for this strategy.
    description: str
    target_modules: List[str] # Regex patterns for module names to be adapted with LoRA
    r: int # LoRA rank of the low-rank update matrices
    lora_alpha: int # Scaling factor (conventionally effective scale is lora_alpha / r — confirm against the adapter implementation)
    lora_dropout: float # Dropout probability applied on the LoRA path
# Shared regex pattern groups, defined once so the three strategies cannot
# drift out of sync (previously each strategy repeated the full list).

# Mask decoder transformer attention layers.
_DECODER_ATTN_PATTERNS = [
    r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn",
    r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image",
    r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token",
    r"sam_mask_decoder\.transformer\.final_attn_token_to_image",
]

# Image encoder attention blocks.
_ENCODER_ATTN_PATTERNS = [
    r"image_encoder\.trunk\.blocks\.\d+\.attn",
]

# Feed-forward (MLP) layers in decoder and encoder.
_FFN_PATTERNS = [
    r"sam_mask_decoder\.transformer\.layers\.\d+\.mlp",
    r"image_encoder\.trunk\.blocks\.\d+\.mlp",
]

# Strategy A: Decoder-Only (Fast Training)
STRATEGY_A = LoRAConfig(
    name="decoder_only",
    description="Fast training, decoder adaptation only (~2-3M params, 2-3 hours)",
    # Copy so that mutating one strategy's list never affects the shared constant.
    target_modules=list(_DECODER_ATTN_PATTERNS),
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
)

# Strategy B: Decoder + Encoder (Balanced) [RECOMMENDED]
STRATEGY_B = LoRAConfig(
    name="decoder_encoder",
    description="Balanced training, best performance/cost ratio (~5-7M params, 4-6 hours)",
    # Decoder attention (as in Strategy A) plus image encoder attention.
    target_modules=_DECODER_ATTN_PATTERNS + _ENCODER_ATTN_PATTERNS,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
)

# Strategy C: Full Model (Maximum Adaptation)
STRATEGY_C = LoRAConfig(
    name="full_model",
    description="Maximum adaptation, highest compute cost (~10-15M params, 8-12 hours)",
    # All attention from Strategy B plus the FFN layers.
    target_modules=_DECODER_ATTN_PATTERNS + _ENCODER_ATTN_PATTERNS + _FFN_PATTERNS,
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
)

# Strategy registry: single-letter keys used by get_strategy().
STRATEGIES = {
    'A': STRATEGY_A,
    'B': STRATEGY_B,
    'C': STRATEGY_C,
}
def get_strategy(strategy_name: str) -> LoRAConfig:
    """Get a LoRA strategy by its registry key.

    Args:
        strategy_name: Registry key ('A', 'B' or 'C'). Lookup is
            case-insensitive, so 'a' selects the same strategy as 'A'
            (backward-compatible convenience for CLI-style input).

    Returns:
        The LoRAConfig registered under that key.

    Raises:
        ValueError: If no strategy is registered under the key.
    """
    # Normalize only strings so non-string keys still fail with the same
    # ValueError below instead of an AttributeError on .upper().
    key = strategy_name.upper() if isinstance(strategy_name, str) else strategy_name
    try:
        return STRATEGIES[key]
    except KeyError:
        # Original error text preserved; suppress the KeyError context so
        # callers see a single, clean ValueError.
        raise ValueError(
            f"Unknown strategy: {strategy_name}. Choose from {list(STRATEGIES.keys())}"
        ) from None