41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
from __future__ import annotations
|
|
|
|
import tomllib
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List
|
|
|
|
from .config import TaskConfig, TaskStepConfig
|
|
|
|
|
|
def load_task_from_toml(path: str | Path) -> TaskConfig:
|
|
"""
|
|
Load a TaskConfig from a TOML file.
|
|
"""
|
|
data = tomllib.loads(Path(path).read_text(encoding="utf-8"))
|
|
task_data = data.get("task", {})
|
|
steps_data: List[Dict[str, Any]] = data.get("steps", [])
|
|
steps = [
|
|
TaskStepConfig(
|
|
kind=step["kind"],
|
|
dataset_split=step.get("dataset_split"),
|
|
dataset_split_file=step.get("dataset_split_file"),
|
|
limit=step.get("limit"),
|
|
eval_split=step.get("eval_split"),
|
|
eval_split_file=step.get("eval_split_file"),
|
|
params=step.get("params", {}),
|
|
)
|
|
for step in steps_data
|
|
]
|
|
return TaskConfig(
|
|
name=task_data["name"],
|
|
description=task_data.get("description", ""),
|
|
project_config_name=task_data["project_config_name"],
|
|
model_key=task_data.get("model_key", "sam2"),
|
|
steps=steps,
|
|
dataset_overrides=task_data.get("dataset_overrides", {}),
|
|
model_overrides=task_data.get("model_overrides", {}),
|
|
training_overrides=task_data.get("training_overrides", {}),
|
|
evaluation_overrides=task_data.get("evaluation_overrides", {}),
|
|
visualization_overrides=task_data.get("visualization_overrides", {}),
|
|
)
|