TOML Training Configuration

Script: examples/toml_training_config.py. Configuration: examples/toml_training_config.toml.

Inputs

  • examples/toml_training_config.toml with [model], [model.params], [training], [loss], [metrics], and [tracking] sections.

  • A deterministic in-memory graph dataset created by the script from small ete3.Tree objects.

  • The fixed node-feature order ("node_time", "time_bin", "branch_length", "is_tip"); its length matches input_dim = 4 in [model.params].

Run command

Run the script from the repository root:

python examples/toml_training_config.py

The script creates a Trainer through create_trainer_from_config(), builds train and validation splits, and calls Trainer.fit().

Expected output

The script prints a run summary, including stable markers for smoke tests, and writes the checkpoint and history files used by the complete pipeline. Abridged output:

TOML training run summary
configured model: GATBiLSTMNet
checkpoint: example_outputs/toml_training_config/final_model.pt
history: example_outputs/toml_training_config/history.json

Stable stdout markers include:

TOML training run summary
configured model: GATBiLSTMNet
metrics: mse, rmse
checkpoint: example_outputs/toml_training_config/final_model.pt
history: example_outputs/toml_training_config/history.json

Files written

  • example_outputs/toml_training_config/final_model.pt

  • example_outputs/toml_training_config/history.json

Failure modes

  • Missing or malformed TOML raises TrainingConfigError from the training configuration loader.

  • Invalid model or trainer keys fail during configuration validation before training starts.

  • The script deletes example_outputs/toml_training_config/ at the start of every run. Output files are therefore always regenerated, and stale files from earlier runs never persist.

Optional dependencies

The default [tracking] section keeps experiment tracking disabled. Install and configure the wandb extra only when you intentionally enable tracking; this local example does not require credentials.

Source

"""TOML-backed training configuration example with a tiny local dataset."""

from pathlib import Path
import shutil

import torch
from ete3 import Tree

from phylognn import TreeFeatureEngineer, TreeToGraphConverter
from phylognn.training import (
    DatasetSplit,
    SplitPhyloDataset,
    create_trainer_from_config,
)

# Repository root: this script lives in <root>/examples/, so step up two levels.
ROOT = Path(__file__).resolve().parents[1]
# TOML file handed to create_trainer_from_config() in main().
CONFIG_PATH = ROOT.joinpath("examples", "toml_training_config.toml")
# Destination for checkpoint/history files.
# main() deletes it before each run and passes it to the trainer as save_dir.
OUTPUT_DIR = ROOT.joinpath("example_outputs", "toml_training_config")
# Node feature columns, in order.
# Four features, matching input_dim = 4 in the TOML's [model.params].
FEATURE_NAMES = ("node_time", "time_bin", "branch_length", "is_tip")


def _build_tree(scale: float) -> Tree:
    """Return a fixed four-tip ete3 tree with every branch length multiplied by *scale*.

    The topology is ((A,B)C,D)root with base lengths A=0.8, B=1.1, C=0.4 and
    D=1.3. Each scaled length is written to the Newick text rounded to two
    decimals. The root branch is always 0.0, and ``format=1`` is used so
    the internal node names (C, root) are kept.
    """
    a_len, b_len, c_len, d_len = (
        f"{base * scale:.2f}" for base in (0.8, 1.1, 0.4, 1.3)
    )
    newick = f"((A:{a_len},B:{b_len})C:{c_len},D:{d_len})root:0.0;"
    return Tree(newick, format=1)


def _make_graph(scale: float, index: int):
    """Featurize one scaled example tree and convert it into a graph sample.

    The FEATURE_NAMES features are attached to the tree in place, with no
    rescaling, ``num_time_bins=6`` and an origin time of ``3.5 + scale``. The
    tree is then converted without virtual nodes and without an is_virtual
    feature column. The converter reuses the feature engineer's traversal
    strategy, presumably so node ordering agrees between the two steps.

    The returned graph carries a ``sample_id`` attribute of the form
    ``toml_sample_NN``, where NN is the zero-padded *index*.
    """
    feature_engineer = TreeFeatureEngineer(num_time_bins=6)
    featurized_tree = feature_engineer.add_features(
        _build_tree(scale),
        origin_time=3.5 + scale,
        feature_names=FEATURE_NAMES,
        rescale=False,
        inplace=True,
    )
    graph_converter = TreeToGraphConverter(
        feature_names=FEATURE_NAMES,
        add_virtual_nodes=False,
        append_is_virtual_feature=False,
        traversal_strategy=feature_engineer.traversal_strategy,
    )
    sample_attrs = {"sample_id": f"toml_sample_{index:02d}"}
    return graph_converter.convert(featurized_tree, graph_attrs=sample_attrs)


def _build_dataset() -> SplitPhyloDataset:
    """Build the eight-sample in-memory dataset used by the example.

    Each sample is a tree scaled by one of eight fixed factors. The
    regression target for a sample is that same scale factor, so the data is
    fully deterministic.
    """
    scales = (0.80, 0.95, 1.10, 1.25, 1.40, 1.55, 1.70, 1.85)

    graphs = []
    for position, scale in enumerate(scales):
        graphs.append(_make_graph(scale, position))

    # One float32 target per graph, shaped (N, 1, 1).
    # NOTE(review): the model is configured with output_dim = 1. Confirm that
    # SplitPhyloDataset/Trainer expect this extra unit axis rather than (N, 1).
    labels = torch.tensor(scales, dtype=torch.float32).view(-1, 1, 1)

    return SplitPhyloDataset(
        data_list=graphs,
        labels=labels,
        sample_ids=[g.sample_id for g in graphs],
    )


def main() -> None:
    """Train the configured model on the tiny in-memory dataset and print a run summary.

    Steps:
    1. Seed torch and delete any previous OUTPUT_DIR.
    2. Build the eight-sample dataset and split it 75/25 into train and
       validation sets (seed 7).
    3. Create the trainer from CONFIG_PATH, overriding ``save_dir`` and
       ``verbose``.
    4. Fit the trainer and print the summary lines that smoke tests match on.

    Output paths are printed relative to ROOT in POSIX form, so the
    ``checkpoint:`` and ``history:`` markers are identical on every platform.
    """
    torch.manual_seed(7)
    # Start from a clean output directory so no stale files from earlier runs remain.
    shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

    dataset = _build_dataset()
    split = DatasetSplit.from_ratios(
        sample_ids=dataset.sample_ids,
        train_ratio=0.75,
        val_ratio=0.25,
        test_ratio=0.0,
        seed=7,
    )
    subsets = dataset.build_subsets(split)
    # The TOML's save_dir is relative to the working directory.
    # Overriding it with an absolute path pins outputs under ROOT.
    trainer = create_trainer_from_config(
        CONFIG_PATH,
        training_overrides={"save_dir": str(OUTPUT_DIR), "verbose": False},
    )
    history = trainer.fit(
        train_dataset=subsets["train"],
        val_dataset=subsets["val"],
    )

    final_train_loss = history["train_loss"][-1]
    final_val_loss = history["val_loss"][-1]
    # as_posix() keeps forward slashes on Windows too, matching the documented markers.
    checkpoint_path = (OUTPUT_DIR / "final_model.pt").relative_to(ROOT).as_posix()
    history_path = (OUTPUT_DIR / "history.json").relative_to(ROOT).as_posix()

    print("TOML training run summary")
    print(f"configured model: {trainer.model.__class__.__name__}")
    print(f"feature order: {FEATURE_NAMES}")
    print(f"dataset sizes: train={len(subsets['train'])}, val={len(subsets['val'])}")
    print(f"epochs: {trainer.config.epochs}")
    print(f"batch_size: {trainer.config.batch_size}")
    print(f"final losses: train={final_train_loss:.4f}, val={final_val_loss:.4f}")
    print(f"metrics: {', '.join(trainer.metrics)}")
    print(f"tracking: {'enabled' if trainer.tracking_config.enabled else 'disabled'}")
    print(f"checkpoint: {checkpoint_path}")
    print(f"history: {history_path}")


if __name__ == "__main__":
    main()
# Training configuration read by examples/toml_training_config.py.

# Model class selected by name; the script prints it as "configured model".
[model]
type = "GATBiLSTMNet"

# NOTE(review): presumably forwarded as keyword arguments to the model constructor.
# Confirm this in the config loader.
[model.params]
# Must equal the number of node features in the script's FEATURE_NAMES (4).
# The script appends no is_virtual column.
input_dim = 4
# One regression target per tree (the script labels each tree with its scale factor).
output_dim = 1
use_preprocessing = false
temporal_mode = "none"
# Small layer sizes keep the example fast.
# dropout_prob = 0.0 removes one source of run-to-run randomness.
gat_hidden_dim = 4
gat_heads = 1
num_gat_layers = 1
dropout_prob = 0.0
graph_pool = "mean"
head_hidden_dim = 4

# epochs and batch_size are echoed in the script's run summary.
[training]
epochs = 2
batch_size = 4
learning_rate = 0.005
weight_decay = 0.0
scheduler = "none"
# NOTE(review): 0 presumably disables early stopping. Confirm in the trainer.
early_stopping_patience = 0
# Overridden at runtime: main() passes training_overrides with an absolute save_dir
# (and verbose = false), so this relative path only applies to other callers.
save_dir = "./example_outputs/toml_training_config"
verbose = false
# Presumably keeps training batch order fixed between runs.
train_shuffle = false

[loss]
name = "mse"

# Metric names; the script prints them as "metrics: mse, rmse".
[metrics]
names = ["mse", "rmse"]

# Experiment tracking (wandb) stays off, so no credentials are needed.
[tracking]
enabled = false