Custom Model Extension#
Audience:

- Users who want to validate a raw PyTorch `nn.Module` inside the scDLKit workflow without registering it as a built-in model.

Prerequisites:

- Install `scdlkit[tutorials]`.
- Understand the PBMC quickstart first.
- Be comfortable reading small PyTorch modules.

Learning goals:

- load PBMC data with Scanpy
- define a small custom autoencoder in pure PyTorch
- wrap it with `wrap_reconstruction_module(...)`
- train it for the `representation` task through `Trainer`
- push the learned latent space back into `adata.obsm`
- continue with Scanpy neighbors and UMAP

Out of scope:

- plugin packaging
- full custom training loops outside scDLKit
- proving the custom model is better than all built-in baselines

This is the supported custom-model path in v0.1.3. `TaskRunner` remains the built-in beginner API; custom wrapped modules are supported through `Trainer` first.

Related APIs:

- `Trainer`: lower-level stable training loop
- `wrap_reconstruction_module(...)`: adapter path for raw PyTorch modules

Next steps:

- Tutorial: `classification_demo.ipynb` or `scgpt_cell_type_annotation.ipynb`
- API: `docs/api/adapters.md`
Published tutorial status
This page is a static notebook copy published for documentation review. It is meant to show the exact workflow and outputs from the last recorded run.
- Last run date (UTC): 2026-03-27 09:25 UTC
- Publication mode: static executed tutorial
- Execution profile: published
- Artifact check in this sync: passed
- Source notebook: `examples/custom_model_extension.ipynb`
Install#
python -m pip install "scdlkit[tutorials]"
This notebook is CPU-friendly. If a CUDA build of PyTorch is installed, the same code path uses GPU automatically through device="auto".
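The exact resolution logic behind `device="auto"` is not shown in this tutorial. A minimal sketch of the likely behavior, assuming the common "prefer CUDA when available" convention (`resolve_device` is a hypothetical helper for illustration, not part of the scDLKit API):

```python
def resolve_device(device: str, cuda_available: bool) -> str:
    """Hypothetical sketch of how device="auto" might resolve.

    Illustrative only; not scDLKit's actual implementation.
    """
    if device == "auto":
        return "cuda" if cuda_available else "cpu"
    return device  # explicit devices pass through unchanged

print(resolve_device("auto", cuda_available=False))  # cpu
```

An explicit `device="cpu"` or `device="cuda"` would bypass the automatic choice entirely.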
from pathlib import Path
import matplotlib.pyplot as plt
import scanpy as sc
import torch
from torch import nn
from scdlkit import Trainer, prepare_data
from scdlkit.adapters import wrap_reconstruction_module
from scdlkit.data import transform_adata
from scdlkit.evaluation import evaluate_predictions, save_markdown_report, save_metrics_table
from scdlkit.visualization import plot_losses
OUTPUT_DIR = Path("artifacts/custom_model_extension")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
device_name = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device_name}")
Using device: cpu
Load PBMC data#
We reuse scanpy.datasets.pbmc3k_processed() so the downstream analysis flow stays familiar.
adata = sc.datasets.pbmc3k_processed()
adata
AnnData object with n_obs × n_vars = 2638 × 1838
obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
var: 'n_cells'
uns: 'draw_graph', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups'
obsm: 'X_pca', 'X_tsne', 'X_umap', 'X_draw_graph_fr'
varm: 'PCs'
obsp: 'distances', 'connectivities'
Prepare the splits#
For the adapter path we use the lower-level API directly: prepare_data(...), then Trainer(...).
prepared = prepare_data(adata, label_key="louvain")
prepared.input_dim
1838
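`prepare_data(...)` hides the splitting step. The sketch below illustrates one plausible shape of that step with plain NumPy; the 80/10/10 fractions are an assumption for this sketch only, not scDLKit's documented behavior:

```python
import numpy as np

# Illustrative stand-in for the index-splitting step inside prepare_data(...).
# The 80/10/10 fractions are assumed for this sketch.
rng = np.random.default_rng(0)
n_obs = 2638                      # cells in pbmc3k_processed
idx = rng.permutation(n_obs)      # shuffle cell indices once

n_train = int(0.8 * n_obs)
n_val = int(0.1 * n_obs)
train_idx = idx[:n_train]
val_idx = idx[n_train:n_train + n_val]
test_idx = idx[n_train + n_val:]  # whatever remains

print(len(train_idx), len(val_idx), len(test_idx))
```

Whatever the real fractions are, the key point is that every cell lands in exactly one split.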
Define a custom PyTorch model#
This is a plain nn.Module, not a built-in scDLKit model class.
class CustomAutoencoder(nn.Module):
def __init__(self, input_dim: int, latent_dim: int = 16) -> None:
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(input_dim, 128),
nn.ReLU(),
nn.Linear(128, latent_dim),
)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, 128),
nn.ReLU(),
nn.Linear(128, input_dim),
)
def encode(self, x: torch.Tensor) -> torch.Tensor:
return self.encoder(x)
def forward(self, x: torch.Tensor) -> torch.Tensor:
latent = self.encode(x)
return self.decoder(latent)
module = CustomAutoencoder(prepared.input_dim, latent_dim=16)
module
CustomAutoencoder(
(encoder): Sequential(
(0): Linear(in_features=1838, out_features=128, bias=True)
(1): ReLU()
(2): Linear(in_features=128, out_features=16, bias=True)
)
(decoder): Sequential(
(0): Linear(in_features=16, out_features=128, bias=True)
(1): ReLU()
(2): Linear(in_features=128, out_features=1838, bias=True)
)
)
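Before wrapping a custom module, it is worth smoke-testing its tensor shapes. This sketch rebuilds the same layer dimensions as `CustomAutoencoder` and pushes a fake batch through:

```python
import torch
from torch import nn

# Rebuild the tutorial's layer shapes and verify them with a fake batch
# before handing the module to the adapter.
encoder = nn.Sequential(nn.Linear(1838, 128), nn.ReLU(), nn.Linear(128, 16))
decoder = nn.Sequential(nn.Linear(16, 128), nn.ReLU(), nn.Linear(128, 1838))

x = torch.randn(4, 1838)  # 4 fake cells with 1838 genes each
latent = encoder(x)
recon = decoder(latent)
print(tuple(latent.shape), tuple(recon.shape))  # (4, 16) (4, 1838)
```

If the latent or reconstruction shapes are wrong here, the wrapper and trainer will fail later with less obvious errors.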
Wrap the module and train it#
The wrapper gives the raw module the output and loss contract expected by scDLKit.
model = wrap_reconstruction_module(
module,
input_dim=prepared.input_dim,
supported_tasks=("representation", "reconstruction"),
)
trainer = Trainer(
model=model,
task="representation",
epochs=10,
batch_size=128,
device="auto",
seed=42,
)
trainer.fit(prepared.train, prepared.val)
trainer.history_frame.tail()
|   | epoch | train_loss | train_reconstruction_loss | val_loss | val_reconstruction_loss |
|---|---|---|---|---|---|
| 5 | 6 | 0.816378 | 0.816378 | 0.837854 | 0.837854 |
| 6 | 7 | 0.810021 | 0.810021 | 0.837128 | 0.837128 |
| 7 | 8 | 0.806898 | 0.806898 | 0.836754 | 0.836754 |
| 8 | 9 | 0.804627 | 0.804627 | 0.836383 | 0.836383 |
| 9 | 10 | 0.801235 | 0.801235 | 0.836139 | 0.836139 |
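The internals of `wrap_reconstruction_module(...)` are opaque here. The sketch below shows one plausible shape of the contract it provides, a dict of named outputs plus an MSE reconstruction loss; the class and method names (`SketchReconstructionWrapper`, `loss`) are assumptions for illustration, not scDLKit internals:

```python
import torch
from torch import nn

class SketchReconstructionWrapper(nn.Module):
    """Hypothetical illustration of the adapter contract: expose a dict of
    outputs ("latent", "reconstruction") and an MSE reconstruction loss.
    Not scDLKit's actual implementation."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module
        self.loss_fn = nn.MSELoss()

    def forward(self, x: torch.Tensor) -> dict:
        # Relies on the wrapped module exposing encode(...) for the latent.
        return {"latent": self.module.encode(x), "reconstruction": self.module(x)}

    def loss(self, outputs: dict, x: torch.Tensor) -> torch.Tensor:
        return self.loss_fn(outputs["reconstruction"], x)

class TinyAE(nn.Module):
    """Toy module with the same encode/forward surface as CustomAutoencoder."""

    def __init__(self) -> None:
        super().__init__()
        self.enc = nn.Linear(8, 2)
        self.dec = nn.Linear(2, 8)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.enc(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dec(self.encode(x))

wrapped = SketchReconstructionWrapper(TinyAE())
x = torch.randn(3, 8)
out = wrapped(x)
loss = wrapped.loss(out, x)
print(sorted(out), tuple(out["latent"].shape), loss.ndim)
```

Under this assumed contract, `Trainer` only needs the output dict and a scalar loss; everything model-specific stays inside the wrapped module.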
Save a loss curve#
loss_fig, _ = plot_losses(trainer.history_frame)
loss_fig.savefig(OUTPUT_DIR / "loss_curve.png", dpi=150, bbox_inches="tight")
plt.close(loss_fig)
OUTPUT_DIR / "loss_curve.png"
PosixPath('artifacts/custom_model_extension/loss_curve.png')
Evaluate the held-out split#
split = prepared.test or prepared.val or prepared.train
predictions = trainer.predict_dataset(split)
metrics = evaluate_predictions("representation", predictions)
metrics
{'mse': 0.8269846439361572,
'mae': 0.4079626798629761,
'pearson': 0.21515628695487976,
'spearman': 0.10855005603030946,
'silhouette': 0.22205692529678345,
'knn_label_consistency': 0.8787878787878788,
'ari': 0.4655291916088742,
'nmi': 0.6799320582228012}
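Of these metrics, `knn_label_consistency` is the least standard. One common definition, which may differ from scDLKit's exact implementation, is the mean fraction of each cell's k nearest latent neighbours that share the cell's label:

```python
import numpy as np

def knn_label_consistency(latent: np.ndarray, labels: np.ndarray, k: int = 2) -> float:
    """Illustrative sketch: mean fraction of each point's k nearest
    neighbours (excluding itself) that share the point's label."""
    dists = np.linalg.norm(latent[:, None, :] - latent[None, :, :], axis=-1)
    np.fill_diagonal(dists, np.inf)           # never count a point as its own neighbour
    nn_idx = np.argsort(dists, axis=1)[:, :k]
    return float((labels[nn_idx] == labels[:, None]).mean())

# Two well-separated clusters: every neighbour shares the label.
latent = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                   [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
labels = np.array([0, 0, 0, 1, 1, 1])
score = knn_label_consistency(latent, labels)
print(score)  # 1.0
```

Values near 1.0 mean the latent space keeps same-label cells together; the 0.879 above suggests the custom autoencoder preserves most of the louvain structure.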
Push the latent space back into Scanpy#
This is the important handoff: scDLKit trains the model, then Scanpy continues the downstream embedding workflow from adata.obsm.
full_split = transform_adata(
adata,
prepared.preprocessing,
label_encoder=prepared.label_encoder,
batch_encoder=prepared.batch_encoder,
)
full_predictions = trainer.predict_dataset(full_split)
adata.obsm["X_scdlkit_custom"] = full_predictions["latent"]
sc.pp.neighbors(adata, use_rep="X_scdlkit_custom")
sc.tl.umap(adata, random_state=42)
umap_fig = sc.pl.umap(adata, color="louvain", return_fig=True, frameon=False)
umap_fig.savefig(OUTPUT_DIR / "latent_umap.png", dpi=150, bbox_inches="tight")
plt.close(umap_fig)
OUTPUT_DIR / "latent_umap.png"
PosixPath('artifacts/custom_model_extension/latent_umap.png')
Save a simple report#
This uses the public low-level evaluation helpers rather than TaskRunner.
report_md = OUTPUT_DIR / "report.md"
report_csv = OUTPUT_DIR / "report.csv"
save_markdown_report(
metrics,
path=report_md,
title="scDLKit custom model extension report",
extra_sections=[
"## Notes",
"",
"- Wrapped model type: `CustomAutoencoder`",
"- Training surface: `Trainer`",
"- Task: `representation`",
],
)
save_metrics_table(metrics, report_csv)
report_md, report_csv
(PosixPath('artifacts/custom_model_extension/report.md'),
PosixPath('artifacts/custom_model_extension/report.csv'))
Expected outputs#
This tutorial should leave four artifacts under artifacts/custom_model_extension/:
- report.md
- report.csv
- loss_curve.png
- latent_umap.png
Validation checklist:

- The loss curve should decrease.
- The latent UMAP should show broad structure rather than collapse.
- The saved report should make it easy to compare your wrapped model against the built-in baselines.
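The first checklist item can be checked programmatically. The values below are the validation losses from `trainer.history_frame.tail()` earlier on this page:

```python
# Validation losses copied from the recorded training history (epochs 6-10).
val_losses = [0.837854, 0.837128, 0.836754, 0.836383, 0.836139]

decreasing = all(later < earlier for earlier, later in zip(val_losses, val_losses[1:]))
print(decreasing)  # True
```

In a live run, the same check can be applied to the full `val_loss` column rather than a hard-coded list.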
Recommended next tutorials:

- PBMC model comparison, to benchmark against built-in baselines
- reconstruction sanity check, if your custom model also exposes reconstructed expression
output_paths = {
"report_md": str(report_md),
"report_csv": str(report_csv),
"loss_curve_png": str(OUTPUT_DIR / "loss_curve.png"),
"latent_umap_png": str(OUTPUT_DIR / "latent_umap.png"),
}
output_paths
{'report_md': 'artifacts/custom_model_extension/report.md',
'report_csv': 'artifacts/custom_model_extension/report.csv',
'loss_curve_png': 'artifacts/custom_model_extension/loss_curve.png',
'latent_umap_png': 'artifacts/custom_model_extension/latent_umap.png'}