# Custom Models
v0.1.3 adds the first extensibility path for user-supplied PyTorch modules.

The supported surface is `Trainer`, not `TaskRunner`. `TaskRunner` remains the beginner-facing path for built-in scDLKit models. If you want to bring your own `nn.Module`, wrap it with an adapter and train it through the lower-level workflow.
## What adapters do
The adapter layer lets you:

- keep your model as a normal `torch.nn.Module`
- reuse scDLKit data preparation and training loops
- evaluate predictions with the same scDLKit metrics helpers
- write the resulting latent space back into `adata.obsm` and continue with Scanpy
This keeps scDLKit focused on rapid prototyping and validation without forcing every custom model into the built-in registry.
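As a conceptual sketch (not the real scDLKit implementation, which is `wrap_reconstruction_module`, shown later on this page), an adapter in this style records which tasks a plain module can serve and exposes that to the trainer. The class and method names below are illustrative only:

```python
# Conceptual sketch only -- the class and method names here are
# hypothetical, not part of the scDLKit API.
class ReconstructionAdapter:
    """Wraps a plain module and records which tasks it supports."""

    def __init__(self, module, input_dim, supported_tasks):
        self.module = module
        self.input_dim = input_dim
        self.supported_tasks = tuple(supported_tasks)

    def supports(self, task: str) -> bool:
        # A trainer can check this before dispatching a task to the module.
        return task in self.supported_tasks


adapter = ReconstructionAdapter(
    module=object(),  # stand-in for a torch.nn.Module
    input_dim=2000,
    supported_tasks=("representation", "reconstruction"),
)
print(adapter.supports("representation"))  # True
print(adapter.supports("classification"))  # False
```

The point of the indirection is that the trainer only talks to the adapter surface, so the underlying module stays an ordinary PyTorch class.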
## Current limitations
The first adapter release is intentionally narrow:

- wrapped models are supported through `Trainer` first
- the module contract is x-only input for now
- full-batch callback plumbing is not part of this release
- foundation-model integrations are still future work
That scope is deliberate. The first goal is to make custom small-model prototyping stable before adding larger foundation-model workflows.
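The x-only contract means the wrapped module's `forward` must accept a single tensor and nothing else. A quick way to check a signature against that shape, using only the standard library (the helper below is illustrative, not part of scDLKit):

```python
import inspect


def accepts_x_only(forward) -> bool:
    """Illustrative check: does this forward take exactly one input (x)?"""
    params = [
        p
        for p in inspect.signature(forward).parameters.values()
        if p.name != "self"
    ]
    return len(params) == 1


# Fits the current contract: forward(self, x)
def forward_x(self, x):
    ...


# Does not fit: extra conditioning arguments are not plumbed through yet
def forward_x_batch(self, x, batch):
    ...


print(accepts_x_only(forward_x))        # True
print(accepts_x_only(forward_x_batch))  # False
```

If your model needs batch labels or other covariates, fold them into the module's state for now rather than extra `forward` arguments.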
## Minimal reconstruction wrapper
```python
import torch
from torch import nn

from scdlkit import Trainer, prepare_data
from scdlkit.adapters import wrap_reconstruction_module


class SmallAutoencoder(nn.Module):
    def __init__(self, input_dim: int, latent_dim: int = 16) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encode(x))


# `adata` is an AnnData object loaded earlier (e.g. with scanpy)
prepared = prepare_data(adata, label_key="louvain")

wrapped = wrap_reconstruction_module(
    SmallAutoencoder(prepared.input_dim),
    input_dim=prepared.input_dim,
    supported_tasks=("representation", "reconstruction"),
)

trainer = Trainer(
    model=wrapped,
    task="representation",
    device="auto",
    epochs=10,
    batch_size=128,
)
trainer.fit(prepared.train, prepared.val)
```
## Evaluation and Scanpy handoff
```python
from scdlkit.data import transform_adata
from scdlkit.evaluation import (
    evaluate_predictions,
    save_markdown_report,
    save_metrics_table,
)

test_predictions = trainer.predict_dataset(prepared.test or prepared.val)
metrics = evaluate_predictions("representation", test_predictions)

save_markdown_report(
    metrics,
    path="artifacts/custom_model_extension/report.md",
    title="scDLKit custom model extension report",
)
save_metrics_table(metrics, "artifacts/custom_model_extension/report.csv")

full_split = transform_adata(
    adata,
    prepared.preprocessing,
    label_encoder=prepared.label_encoder,
    batch_encoder=prepared.batch_encoder,
)
full_predictions = trainer.predict_dataset(full_split)
adata.obsm["X_scdlkit_custom"] = full_predictions["latent"]
```
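Before handing the embedding to Scanpy, it is worth confirming it has one row per cell, since `obsm` entries must align with `adata` row-for-row. A minimal sketch with stand-in values (the real check would compare `full_predictions["latent"]` against `adata.n_obs`):

```python
import numpy as np

n_obs = 5                   # stand-in for adata.n_obs
latent = np.zeros((5, 16))  # stand-in for full_predictions["latent"]

# obsm entries must have exactly one row per observation
assert latent.shape[0] == n_obs
print(latent.shape)  # (5, 16)
```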
From there, continue with standard Scanpy steps:
```python
import scanpy as sc

sc.pp.neighbors(adata, use_rep="X_scdlkit_custom")
sc.tl.umap(adata, random_state=42)
sc.pl.umap(adata, color="louvain")
```
## Tutorial
For the full end-to-end walkthrough, see the rendered notebook:
This notebook defines a small custom autoencoder directly in the tutorial, trains it through `Trainer`, evaluates the result, and saves artifacts under `artifacts/custom_model_extension/`.