2025-12-24 13:43:34 +08:00

41 lines
1.4 KiB
Python

from __future__ import annotations
import tomllib
from pathlib import Path
from typing import Any, Dict, List
from .config import TaskConfig, TaskStepConfig
def load_task_from_toml(path: str | Path) -> TaskConfig:
    """
    Parse the TOML file at ``path`` and build a ``TaskConfig`` from it.

    Task-level fields are read from the ``task`` table and one
    ``TaskStepConfig`` is created per entry of the top-level ``steps``
    array. Both sections are treated as empty when absent.

    Raises:
        KeyError: if the ``task`` table has no ``name`` or
            ``project_config_name``, or a step entry has no ``kind``.
    """
    toml_text = Path(path).read_text(encoding="utf-8")
    document = tomllib.loads(toml_text)
    task_table = document.get("task", {})
    raw_steps: List[Dict[str, Any]] = document.get("steps", [])

    # Step fields that fall back to None when the entry omits them.
    optional_step_keys = (
        "dataset_split",
        "dataset_split_file",
        "limit",
        "eval_split",
        "eval_split_file",
    )
    steps = []
    for entry in raw_steps:
        # "kind" is mandatory, so it is read first with a plain subscript.
        step_kwargs = {"kind": entry["kind"]}
        for key in optional_step_keys:
            step_kwargs[key] = entry.get(key)
        step_kwargs["params"] = entry.get("params", {})
        steps.append(TaskStepConfig(**step_kwargs))

    # Required keys use subscripts; the rest carry explicit defaults.
    task_kwargs = {
        "name": task_table["name"],
        "description": task_table.get("description", ""),
        "project_config_name": task_table["project_config_name"],
        "model_key": task_table.get("model_key", "sam2"),
        "steps": steps,
    }
    # Every override section defaults to an empty mapping.
    for key in (
        "dataset_overrides",
        "model_overrides",
        "training_overrides",
        "evaluation_overrides",
        "visualization_overrides",
    ):
        task_kwargs[key] = task_table.get(key, {})
    return TaskConfig(**task_kwargs)