from __future__ import annotations import tomllib from pathlib import Path from typing import Any, Dict, List from .config import TaskConfig, TaskStepConfig def load_task_from_toml(path: str | Path) -> TaskConfig: """ Load a TaskConfig from a TOML file. """ data = tomllib.loads(Path(path).read_text(encoding="utf-8")) task_data = data.get("task", {}) steps_data: List[Dict[str, Any]] = data.get("steps", []) steps = [ TaskStepConfig( kind=step["kind"], dataset_split=step.get("dataset_split"), dataset_split_file=step.get("dataset_split_file"), limit=step.get("limit"), eval_split=step.get("eval_split"), eval_split_file=step.get("eval_split_file"), params=step.get("params", {}), ) for step in steps_data ] return TaskConfig( name=task_data["name"], description=task_data.get("description", ""), project_config_name=task_data["project_config_name"], model_key=task_data.get("model_key", "sam2"), steps=steps, dataset_overrides=task_data.get("dataset_overrides", {}), model_overrides=task_data.get("model_overrides", {}), training_overrides=task_data.get("training_overrides", {}), evaluation_overrides=task_data.get("evaluation_overrides", {}), visualization_overrides=task_data.get("visualization_overrides", {}), )