Training Reference
Import path
from phylognn.training import TrainingConfig, Trainer, load_training_config
Configuration and trainer
- class TrainingConfig
Dataclass controlling epochs, batch size, learning rate, optimizer, scheduler, device, checkpoint directory, early stopping, gradient clipping, and loader options.
- validate()
Raise ValueError when a setting is outside the supported range.
- class Trainer(model, config, loss_fn=None, metrics=None, tracking_config=None, tracking_metadata=None, tracker=None)
Local trainer for PyG datasets and loaders. Samples must expose targets as data.y.
- fit(train_loader=None, val_loader=None, train_dataset=None, val_dataset=None)
Train the model and update local history and checkpoints.
- predict(loader=None, dataset=None, return_targets=False)
Run inference from a loader or dataset.
- save_checkpoint(filename)
Save training state under config.save_dir.
- load_checkpoint(filename)
Restore model, optimizer, scheduler, and history state from a trusted checkpoint file.
- save_history(filename='history.json')
Write local training history as JSON.
- exception TrainingConfigError
Raised when a TOML training configuration cannot be parsed or validated.
- class ConfiguredTrainingSetup
Effective setup returned by load_training_config(): model, training config, loss function, metrics, tracking config, and sanitized tracking metadata.
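TrainingConfig.validate() raises ValueError for out-of-range settings. The sketch below illustrates that pattern on a small stand-in dataclass; the class name SketchTrainingConfig, its field names, its defaults, and the exact range rules are assumptions for illustration, not the real TrainingConfig fields.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchTrainingConfig:
    """Hypothetical stand-in for TrainingConfig; fields and ranges are assumed."""

    epochs: int = 100
    batch_size: int = 32
    learning_rate: float = 1e-3
    grad_clip_norm: Optional[float] = None  # None disables clipping in this sketch
    early_stopping_patience: Optional[int] = None  # None disables early stopping

    def validate(self) -> None:
        """Raise ValueError when a setting falls outside the assumed supported range."""
        if self.epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {self.epochs}")
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.learning_rate <= 0:
            raise ValueError(f"learning_rate must be > 0, got {self.learning_rate}")
        if self.grad_clip_norm is not None and self.grad_clip_norm <= 0:
            raise ValueError(f"grad_clip_norm must be > 0 or None, got {self.grad_clip_norm}")
        if self.early_stopping_patience is not None and self.early_stopping_patience < 1:
            raise ValueError(
                f"early_stopping_patience must be >= 1 or None, got {self.early_stopping_patience}"
            )


# Usage: a zero-epoch run is rejected before any training starts.
try:
    SketchTrainingConfig(epochs=0).validate()
except ValueError as exc:
    print(exc)  # epochs must be >= 1, got 0
```

Validating eagerly like this surfaces configuration mistakes before a model, optimizer, or loader is ever built.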
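Trainer.save_history() writes local training history as JSON. A minimal sketch of writing such a file and reading it back to pick the best epoch follows; the helper names write_history and best_epoch, and the layout of per-epoch lists under train_loss and val_loss keys, are hypothetical, so inspect a real history.json before relying on its keys.

```python
import json
import tempfile
from pathlib import Path
from typing import Dict, List


def write_history(
    history: Dict[str, List[float]], save_dir: str, filename: str = "history.json"
) -> Path:
    """Write per-epoch history as JSON under save_dir (hypothetical helper)."""
    path = Path(save_dir) / filename
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(history, indent=2), encoding="utf-8")
    return path


def best_epoch(path: Path, key: str = "val_loss") -> int:
    """Return the 0-based epoch with the lowest value for key (assumed layout)."""
    history = json.loads(path.read_text(encoding="utf-8"))
    values = history[key]
    return min(range(len(values)), key=lambda i: values[i])


# Usage: write a small history into a temporary directory and read it back.
with tempfile.TemporaryDirectory() as tmp:
    saved = write_history({"train_loss": [1.0, 0.5, 0.4], "val_loss": [0.9, 0.6, 0.7]}, tmp)
    print(best_epoch(saved))  # 1
```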
Datasets and splits
- class DatasetSplit
Deterministic train, validation, and test split definition.
- classmethod from_ratios(sample_ids, train_ratio, val_ratio, test_ratio, seed=42, shuffle=True)
Build a split from ratios.
- classmethod from_dict(splits)
Build a split from explicit sample IDs.
- classmethod from_manifest_dir(manifest_dir)
Load train.txt, val.txt, and test.txt manifests from manifest_dir.
- class SplitDatasetView
Lightweight view over a split-specific subset of a base dataset.
- class SplitPhyloDataset(data_list, labels=None, sample_ids=None, transform=None, pre_transform=None)
In-memory dataset for graph objects and labels.
- class SplitPhyloDiskDataset(graph_dir, label_dir=None, sample_ids=None, recursive=False, cache=True, transform=None, pre_transform=None)
Disk-backed dataset for trusted, mirrored graph and label .pt files.
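DatasetSplit.from_ratios() takes a seed and a shuffle flag, which is what makes the split reproducible. The plain-Python sketch below shows one way such a split can work; the function name split_from_ratios, the ratio-sum check, and the choice to floor the train and validation counts and give the remainder to test are assumptions, not the library's documented behavior.

```python
import random
from typing import Dict, List, Sequence


def split_from_ratios(
    sample_ids: Sequence[str],
    train_ratio: float,
    val_ratio: float,
    test_ratio: float,
    seed: int = 42,
    shuffle: bool = True,
) -> Dict[str, List[str]]:
    """Deterministically split sample IDs into train/val/test lists (illustrative)."""
    total = train_ratio + val_ratio + test_ratio
    if abs(total - 1.0) > 1e-6:
        raise ValueError(f"ratios must sum to 1.0, got {total}")
    ids = list(sample_ids)
    if shuffle:
        # A dedicated Random(seed) leaves the global RNG untouched and yields
        # the same order for the same seed and the same input order.
        random.Random(seed).shuffle(ids)
    n_train = int(len(ids) * train_ratio)  # floor (assumed rounding rule)
    n_val = int(len(ids) * val_ratio)
    return {
        "train": ids[:n_train],
        "val": ids[n_train:n_train + n_val],
        "test": ids[n_train + n_val:],  # remainder goes to test
    }


# Usage: 20 IDs split 50/25/25 with a fixed seed.
split = split_from_ratios([f"s{i}" for i in range(20)], 0.5, 0.25, 0.25, seed=7)
print({name: len(ids) for name, ids in split.items()})  # {'train': 10, 'val': 5, 'test': 5}
```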
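DatasetSplit.from_manifest_dir() reads train.txt, val.txt, and test.txt. The sketch below assumes one sample ID per line, ignores blank lines, and rejects any ID listed more than once across the manifests; the helper name read_manifests and both of those rules are hypothetical.

```python
from pathlib import Path
from typing import Dict, List


def read_manifests(manifest_dir: str) -> Dict[str, List[str]]:
    """Read train.txt, val.txt, and test.txt from manifest_dir (assumed format)."""
    splits: Dict[str, List[str]] = {}
    for name in ("train", "val", "test"):
        # A missing manifest raises FileNotFoundError from read_text.
        text = (Path(manifest_dir) / f"{name}.txt").read_text(encoding="utf-8")
        # Assumed convention: one sample ID per line, blank lines ignored.
        splits[name] = [line.strip() for line in text.splitlines() if line.strip()]

    # Assumed rule: each sample ID may appear only once across all manifests.
    seen: Dict[str, str] = {}
    for name, ids in splits.items():
        for sample_id in ids:
            if sample_id in seen:
                raise ValueError(
                    f"sample {sample_id!r} listed in {seen[sample_id]} and again in {name}"
                )
            seen[sample_id] = name
    return splits
```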
TOML helpers
- load_training_config(path, *, model_overrides=None, training_overrides=None, loss=None, metrics=None)
Load a TOML file and return ConfiguredTrainingSetup.
- create_trainer_from_config(path, *, model_overrides=None, training_overrides=None, loss=None, metrics=None)
Create a Trainer from a TOML configuration.
- create_default_trainer(model, **kwargs)
Create a trainer with default local settings.
Metrics
Trainer metrics are configured with supported string keys or direct
torchmetrics.Metric instances. Built-in keys are:
| Key | Metric |
|---|---|
| | Mean squared error |
| | Mean absolute error |
| | Root mean squared error |
| | R-squared score for scalar outputs by default |
| | Mean absolute percentage error |
The legacy helper functions mse_metric, mae_metric, rmse_metric,
r2_metric, and relative_error_metric are no longer part of the public API.
Use the keys above or instantiate TorchMetrics classes directly.
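For reference, the built-in metrics correspond to the standard formulas sketched below. These plain-Python functions are illustrations only, not the torchmetrics implementations the Trainer uses; the MAPE sketch follows the torchmetrics convention of returning a fraction (not multiplied by 100) with the denominator clamped by a small epsilon, which is worth checking against the installed torchmetrics version.

```python
import math
from typing import Sequence


def _check(y_true: Sequence[float], y_pred: Sequence[float]) -> None:
    # zip() would silently truncate mismatched inputs, so reject them up front.
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError("inputs must be non-empty and the same length")


def mse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean squared error: mean of (y - y_hat)^2."""
    _check(y_true, y_pred)
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def mae(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Mean absolute error: mean of |y - y_hat|."""
    _check(y_true, y_pred)
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """Root mean squared error: sqrt(MSE)."""
    return math.sqrt(mse(y_true, y_pred))


def r2(y_true: Sequence[float], y_pred: Sequence[float]) -> float:
    """R-squared for a single scalar output: 1 - SS_res / SS_tot."""
    _check(y_true, y_pred)
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    if ss_tot == 0:
        raise ValueError("R-squared is undefined when all targets are equal")
    return 1.0 - ss_res / ss_tot


def mape(y_true: Sequence[float], y_pred: Sequence[float], eps: float = 1.17e-6) -> float:
    """Mean absolute percentage error as a fraction; eps guards against zero targets."""
    _check(y_true, y_pred)
    return sum(abs(t - p) / max(eps, abs(t)) for t, p in zip(y_true, y_pred)) / len(y_true)


# Usage: one prediction is off by 1.0 out of four samples.
y_true, y_pred = [1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 5.0]
print(mse(y_true, y_pred), rmse(y_true, y_pred), r2(y_true, y_pred))  # 0.25 0.5 0.8
```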
Tracking
- class TrackingConfig
Optional experiment tracking settings. Disabled tracking imports no external backend.
- class TrackingRunInfo
External run identity returned after tracking starts.
- exception TrackingError
Raised when tracking configuration or logging fails.
- class WandbTracker(tracking_config)
Weights & Biases adapter with a lazy wandb import.
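TrackingConfig and WandbTracker are described as importing their backend lazily, so disabled tracking never touches an external package. A minimal sketch of that pattern with importlib follows; the class LazyBackendTracker, its start() method, and the use of RuntimeError for a missing backend are hypothetical (the library's own error type for tracking failures is TrackingError).

```python
import importlib
from types import ModuleType
from typing import Optional


class LazyBackendTracker:
    """Hypothetical sketch: import a tracking backend only when tracking is enabled."""

    def __init__(self, backend_module: str, enabled: bool = False) -> None:
        self.backend_module = backend_module
        self.enabled = enabled
        self._backend: Optional[ModuleType] = None

    def start(self) -> Optional[ModuleType]:
        """Return the backend module, or None when tracking is disabled."""
        if not self.enabled:
            return None  # disabled tracking never imports the backend
        if self._backend is None:
            try:
                # Deferred import: runs only on the first start() of an enabled tracker.
                self._backend = importlib.import_module(self.backend_module)
            except ImportError as exc:
                raise RuntimeError(
                    f"tracking backend {self.backend_module!r} is not installed"
                ) from exc
        return self._backend


# Usage: a disabled tracker works even if the backend package is absent.
print(LazyBackendTracker("wandb", enabled=False).start())  # None
```

Deferring the import keeps the backend an optional dependency: local-only training runs do not need it installed.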
Configuration sections and outputs
TOML files use [model], [model.params], [training], optional [loss],
optional [metrics], and optional [tracking]. Training writes local
checkpoints and history.json; enabled tracking logs sanitized run metadata
and metrics externally.