Complete Pipeline
Script: examples/complete_pipeline.py.
Inputs
- A tiny in-memory `ete3.Tree` created inside the script.
- Feature order `("node_time", "time_bin", "branch_length", "is_tip")`.
- TOML model settings from `examples/toml_training_config.toml`.
- Optional checkpoint `example_outputs/toml_training_config/final_model.pt`.
Run command
Run the pipeline directly from the repository root:
python examples/complete_pipeline.py
The pipeline applies `TreeFeatureEngineer.add_features()`, converts the tree
with `TreeToGraphConverter`, and creates a matching trainer from the TOML config.
It then loads the standard checkpoint when it exists. Otherwise it saves the
freshly created model (which this script does not train) to a temporary
checkpoint and loads that instead. Either way, the prediction comes from
`Trainer.predict()`.
Expected output
Stable stdout markers include:
- `Complete pipeline summary`
- `checkpoint:`
- `graph x shape:`
- `graph edge_index shape:`
- `prediction:`
Files written
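No persistent files are written. If `example_outputs/toml_training_config/final_model.pt`
exists, the script loads it. If it is missing, the script saves a temporary
checkpoint in a system temporary directory and removes it before exiting. In that
case the `checkpoint:` line shows the path of the (already deleted) temporary file.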
Optional dependencies
None.
Failure modes
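Invalid graph fields fail through the existing model and trainer validation paths. A missing standard checkpoint is handled internally with a temporary checkpoint. Any exception raised while running the pipeline is reported on stderr as `Complete pipeline failed: <message>` and then re-raised, so the traceback and the non-zero exit status are preserved.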
Source
"""Complete tree-to-prediction workflow using the TOML training checkpoint."""
from pathlib import Path
import sys
import tempfile
import torch
from ete3 import Tree
from phylognn import TreeFeatureEngineer, TreeToGraphConverter
from phylognn.training import create_trainer_from_config
# Repository root. The script lives in <root>/examples/, so this is the parent
# of the directory that contains this file.
ROOT = Path(__file__).resolve().parents[1]
# TOML file holding the model/training settings shared with the
# toml_training_config example.
CONFIG_PATH = ROOT.joinpath("examples", "toml_training_config.toml")
# Output directory of the TOML training example (presumably written by that
# example's run — confirm).
OUTPUT_DIR = ROOT.joinpath("example_outputs", "toml_training_config")
# Standard checkpoint. It is loaded when present; otherwise main() falls back
# to a temporary checkpoint.
CHECKPOINT_PATH = OUTPUT_DIR.joinpath("final_model.pt")
# Node feature names, passed identically to the feature engineer and to the
# graph converter.
FEATURE_NAMES = ("node_time", "time_bin", "branch_length", "is_tip")
def _build_tree() -> Tree:
    """Return the small, fixed example tree used by the pipeline.

    The Newick string names both leaves (A, B, D) and internal nodes
    (C, root). It is parsed with ete3 Newick ``format=1``, which reads
    internal node names together with branch lengths.
    """
    newick = "((A:0.92,B:1.18)C:0.42,D:1.36)root:0.0;"
    return Tree(newick, format=1)
def _build_graph():
    """Featurize the example tree and convert it into a graph object.

    A fresh tree is featurized in place (six time bins, origin time 4.2,
    no rescaling). It is then converted with the same ``FEATURE_NAMES`` and
    the engineer's traversal strategy, presumably so node ordering agrees
    between the two steps. No virtual nodes and no ``is_virtual`` feature
    column are added. The graph carries ``sample_id="pipeline_tree"``.
    """
    feature_engineer = TreeFeatureEngineer(num_time_bins=6)
    featurized = feature_engineer.add_features(
        _build_tree(),
        feature_names=FEATURE_NAMES,
        origin_time=4.2,
        inplace=True,
        rescale=False,
    )
    graph_converter = TreeToGraphConverter(
        feature_names=FEATURE_NAMES,
        append_is_virtual_feature=False,
        add_virtual_nodes=False,
        traversal_strategy=feature_engineer.traversal_strategy,
    )
    sample_attrs = {"sample_id": "pipeline_tree"}
    return graph_converter.convert(featurized, graph_attrs=sample_attrs)
def _predict_with_checkpoint(graph, checkpoint_dir: Path, checkpoint_name: str) -> float:
    """Load ``checkpoint_name`` and return a scalar prediction for ``graph``.

    A fresh trainer is built from ``CONFIG_PATH``, with ``save_dir``
    overridden to ``checkpoint_dir`` and verbose output disabled. The named
    checkpoint is loaded into it, and ``graph`` is predicted as a one-item
    batch.

    Returns:
        The first element of the flattened prediction (presumably a torch
        tensor), converted to a Python float.
    """
    overrides = {"save_dir": str(checkpoint_dir), "verbose": False}
    trainer = create_trainer_from_config(CONFIG_PATH, training_overrides=overrides)
    # NOTE(review): presumably the name is resolved relative to save_dir — confirm.
    trainer.load_checkpoint(checkpoint_name)
    output = trainer.predict([graph], batch_size=1)
    first = output.reshape(-1)[0]
    return float(first.item())
def _predict_with_temporary_checkpoint(graph) -> tuple[float, Path]:
    """Predict on ``graph`` through a checkpoint saved to a throwaway directory.

    This is the fallback when the standard checkpoint is absent. A trainer is
    created from ``CONFIG_PATH`` and its current model (not trained by this
    script) is saved as ``final_model.pt`` inside a temporary directory. That
    checkpoint is then reloaded through ``_predict_with_checkpoint``.

    Returns:
        ``(prediction, checkpoint_path)``. The temporary directory is deleted
        before this function returns, so ``checkpoint_path`` no longer exists
        on disk and is only useful as a label.
    """
    name = "final_model.pt"
    with tempfile.TemporaryDirectory(prefix="phylognn_complete_pipeline_") as tmp:
        scratch = Path(tmp)
        saver = create_trainer_from_config(
            CONFIG_PATH,
            training_overrides={"save_dir": str(scratch), "verbose": False},
        )
        saver.save_checkpoint(name)
        value = _predict_with_checkpoint(graph, scratch, name)
    return value, scratch / name
def main() -> None:
    """Run the example end to end and print a short summary to stdout.

    Seeds torch's RNG with 7 and builds the featurized graph. It predicts with
    the standard checkpoint when that file exists, and with a temporary
    checkpoint otherwise. It then prints the checkpoint label, the shapes of
    ``graph.x`` and ``graph.edge_index``, and the prediction to 4 decimals.
    """
    torch.manual_seed(7)
    graph = _build_graph()
    have_standard_checkpoint = CHECKPOINT_PATH.is_file()
    if not have_standard_checkpoint:
        # The label is the path of an already-deleted temporary file.
        value, checkpoint_label = _predict_with_temporary_checkpoint(graph)
    else:
        value = _predict_with_checkpoint(graph, OUTPUT_DIR, "final_model.pt")
        checkpoint_label = CHECKPOINT_PATH.relative_to(ROOT)
    print("Complete pipeline summary")
    print(f"checkpoint: {checkpoint_label}")
    print(f"graph x shape: {tuple(graph.x.shape)}")
    print(f"graph edge_index shape: {tuple(graph.edge_index.shape)}")
    print(f"prediction: {value:.4f}")
if __name__ == "__main__":
    # Report any failure on stderr with a short prefix, then re-raise so the
    # full traceback and a non-zero exit status are preserved. SystemExit and
    # KeyboardInterrupt derive from BaseException rather than Exception, so
    # they already propagate untouched; no separate handler is needed for them.
    try:
        main()
    except Exception as exc:
        print(f"Complete pipeline failed: {exc}", file=sys.stderr)
        raise