"""
|
|
LoRA configuration strategies for SAM2 fine-tuning
|
|
Defines three strategies: Decoder-only, Decoder+Encoder, Full Model
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from typing import List
|
|
|
|
|


@dataclass
class LoRAConfig:
    """LoRA configuration for a specific strategy."""

    name: str
    description: str
    target_modules: List[str]  # Regex patterns for module names
    r: int  # LoRA rank
    lora_alpha: int  # Scaling factor
    lora_dropout: float
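

# Note: each strategy below sets lora_alpha = 2 * r, so the standard LoRA
# scaling factor (lora_alpha / r) stays at 2 across strategies; only the rank,
# and hence adapter capacity, changes.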


# Strategy A: Decoder-Only (Fast Training)
STRATEGY_A = LoRAConfig(
    name="decoder_only",
    description="Fast training, decoder adaptation only (~2-3M params, 2-3 hours)",
    target_modules=[
        # Mask decoder transformer attention layers
        r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token",
        r"sam_mask_decoder\.transformer\.final_attn_token_to_image",
    ],
    r=8,
    lora_alpha=16,
    lora_dropout=0.05
)


# Strategy B: Decoder + Encoder (Balanced) [RECOMMENDED]
STRATEGY_B = LoRAConfig(
    name="decoder_encoder",
    description="Balanced training, best performance/cost ratio (~5-7M params, 4-6 hours)",
    target_modules=[
        # Mask decoder attention (from Strategy A)
        r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token",
        r"sam_mask_decoder\.transformer\.final_attn_token_to_image",
        # Image encoder attention
        r"image_encoder\.trunk\.blocks\.\d+\.attn",
    ],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05
)


# Strategy C: Full Model (Maximum Adaptation)
STRATEGY_C = LoRAConfig(
    name="full_model",
    description="Maximum adaptation, highest compute cost (~10-15M params, 8-12 hours)",
    target_modules=[
        # All attention from Strategy B
        r"sam_mask_decoder\.transformer\.layers\.\d+\.self_attn",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_token_to_image",
        r"sam_mask_decoder\.transformer\.layers\.\d+\.cross_attn_image_to_token",
        r"sam_mask_decoder\.transformer\.final_attn_token_to_image",
        r"image_encoder\.trunk\.blocks\.\d+\.attn",
        # FFN layers
        r"sam_mask_decoder\.transformer\.layers\.\d+\.mlp",
        r"image_encoder\.trunk\.blocks\.\d+\.mlp",
    ],
    r=32,
    lora_alpha=64,
    lora_dropout=0.1
)


# Strategy registry
STRATEGIES = {
    'A': STRATEGY_A,
    'B': STRATEGY_B,
    'C': STRATEGY_C
}


def get_strategy(strategy_name: str) -> LoRAConfig:
    """Get a LoRA strategy by name."""
    if strategy_name not in STRATEGIES:
        raise ValueError(
            f"Unknown strategy: {strategy_name}. Choose from {list(STRATEGIES.keys())}"
        )
    return STRATEGIES[strategy_name]
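

# --- Illustrative usage sketch (not part of the original pipeline) -----------
# A minimal sketch of how a strategy's regex patterns might be matched against
# a model's module names to locate the layers to wrap with LoRA. Assumptions:
# `model` is a loaded SAM2 torch.nn.Module and the patterns are meant to
# full-match qualified module names; the actual LoRA injection (e.g. via peft
# or custom wrapper layers) lives elsewhere in the training code.
def find_target_modules(model, config: LoRAConfig) -> List[str]:
    """Return the qualified names of modules matching any pattern in `config`."""
    import re  # local import keeps this sketch self-contained

    matched = []
    for name, _module in model.named_modules():
        if any(re.fullmatch(pattern, name) for pattern in config.target_modules):
            matched.append(name)
    return matched


# Hypothetical usage:
#   strategy = get_strategy('B')
#   targets = find_target_modules(sam2_model, strategy)
#   print(f"{strategy.name}: {len(targets)} modules targeted at rank r={strategy.r}")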