Custom Model Extension#

Audience:

  • Users who want to validate a raw PyTorch nn.Module inside the scDLKit workflow without registering it as a built-in model.

Prerequisites:

  • Install scdlkit[tutorials].

  • Understand the PBMC quickstart first.

  • Be comfortable reading small PyTorch modules.

Learning goals:

  • load PBMC data with Scanpy

  • define a small custom autoencoder in pure PyTorch

  • wrap it with wrap_reconstruction_module(...)

  • train it for the representation task through Trainer

  • push the learned latent space back into adata.obsm

  • continue with Scanpy neighbors and UMAP

Out of scope:

  • plugin packaging

  • full custom training loops outside scDLKit

  • proving the custom model is better than all built-in baselines

This is the supported custom-model path in v0.1.3. TaskRunner remains the built-in beginner API; wrapped custom modules are supported through the lower-level Trainer.

Related APIs:

  • Trainer: lower-level stable training loop

  • wrap_reconstruction_module(...): adapter path for raw PyTorch modules

Next steps:

  • Tutorial: classification_demo.ipynb or scgpt_cell_type_annotation.ipynb

  • API: docs/api/adapters.md

Published tutorial status#

This page is a static notebook copy published for documentation review. It is meant to show the exact workflow and outputs from the last recorded run.

  • Last run date (UTC): 2026-03-27 09:25 UTC

  • Publication mode: static executed tutorial

  • Execution profile: published

  • Artifact check in this sync: passed

  • Source notebook: examples/custom_model_extension.ipynb

Install#

python -m pip install "scdlkit[tutorials]"

This notebook is CPU-friendly. If a CUDA build of PyTorch is installed, the same code runs on the GPU automatically via device="auto".

from pathlib import Path

import matplotlib.pyplot as plt
import scanpy as sc
import torch
from torch import nn

from scdlkit import Trainer, prepare_data
from scdlkit.adapters import wrap_reconstruction_module
from scdlkit.data import transform_adata
from scdlkit.evaluation import evaluate_predictions, save_markdown_report, save_metrics_table
from scdlkit.visualization import plot_losses

OUTPUT_DIR = Path("artifacts/custom_model_extension")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
device_name = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device_name}")
Using device: cpu
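The device="auto" behavior can be sketched as ordinary resolution logic. The resolve_device helper below is hypothetical, not part of the scDLKit API; it just illustrates how "auto" typically falls back from CUDA to CPU:

```python
def resolve_device(device: str, cuda_available: bool) -> str:
    """Resolve a device spec: "auto" picks CUDA when available, else CPU."""
    if device == "auto":
        return "cuda" if cuda_available else "cpu"
    return device  # explicit specs like "cpu" or "cuda:0" pass through


print(resolve_device("auto", cuda_available=False))  # cpu
```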

Load PBMC data#

We reuse scanpy.datasets.pbmc3k_processed() so the downstream analysis flow stays familiar.

adata = sc.datasets.pbmc3k_processed()
adata
AnnData object with n_obs × n_vars = 2638 × 1838
    obs: 'n_genes', 'percent_mito', 'n_counts', 'louvain'
    var: 'n_cells'
    uns: 'draw_graph', 'louvain', 'louvain_colors', 'neighbors', 'pca', 'rank_genes_groups'
    obsm: 'X_pca', 'X_tsne', 'X_umap', 'X_draw_graph_fr'
    varm: 'PCs'
    obsp: 'distances', 'connectivities'

Prepare the splits#

For the adapter path we use the lower-level API directly: prepare_data(...), then Trainer(...).

prepared = prepare_data(adata, label_key="louvain")
prepared.input_dim
1838

Define a custom PyTorch model#

This is a plain nn.Module, not a built-in scDLKit model class.

class CustomAutoencoder(nn.Module):
    def __init__(self, input_dim: int, latent_dim: int = 16) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, input_dim),
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        latent = self.encode(x)
        return self.decoder(latent)


module = CustomAutoencoder(prepared.input_dim, latent_dim=16)
module
CustomAutoencoder(
  (encoder): Sequential(
    (0): Linear(in_features=1838, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=16, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=16, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=1838, bias=True)
  )
)
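As a sanity check on the printout above, the trainable parameter count of this architecture can be computed by hand; ReLU layers contribute no parameters:

```python
def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weight matrix + bias vector


encoder = linear_params(1838, 128) + linear_params(128, 16)
decoder = linear_params(16, 128) + linear_params(128, 1838)
print(encoder + decoder)  # 476734 trainable parameters in total
```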

Wrap the module and train it#

The wrapper gives the raw module the output and loss contract expected by scDLKit.

model = wrap_reconstruction_module(
    module,
    input_dim=prepared.input_dim,
    supported_tasks=("representation", "reconstruction"),
)

trainer = Trainer(
    model=model,
    task="representation",
    epochs=10,
    batch_size=128,
    device="auto",
    seed=42,
)
trainer.fit(prepared.train, prepared.val)
trainer.history_frame.tail()
epoch train_loss train_reconstruction_loss val_loss val_reconstruction_loss
5 6 0.816378 0.816378 0.837854 0.837854
6 7 0.810021 0.810021 0.837128 0.837128
7 8 0.806898 0.806898 0.836754 0.836754
8 9 0.804627 0.804627 0.836383 0.836383
9 10 0.801235 0.801235 0.836139 0.836139
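The contract the wrapper provides can be sketched as follows. This is a simplified stand-in, not the actual wrap_reconstruction_module implementation: the idea is that any module exposing encode(x) and a forward call gets a uniform latent/reconstruction/loss interface the training loop can drive.

```python
import numpy as np


class ReconstructionAdapter:
    """Sketch of the adapter contract (hypothetical): wrap a module that
    exposes encode(x) and __call__(x), and return the outputs and loss a
    reconstruction-style training loop expects."""

    def __init__(self, module):
        self.module = module

    def step(self, batch):
        latent = self.module.encode(batch)
        recon = self.module(batch)
        loss = float(np.mean((recon - batch) ** 2))  # MSE reconstruction loss
        return {"latent": latent, "reconstruction": recon, "loss": loss}


class IdentityModule:
    """Toy module for demonstration: latent == reconstruction == input."""

    def encode(self, x):
        return x

    def __call__(self, x):
        return x


adapter = ReconstructionAdapter(IdentityModule())
out = adapter.step(np.ones((4, 3)))
print(out["loss"])  # 0.0 for a perfect reconstruction
```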

Save a loss curve#

loss_fig, _ = plot_losses(trainer.history_frame)
loss_fig.savefig(OUTPUT_DIR / "loss_curve.png", dpi=150, bbox_inches="tight")
plt.close(loss_fig)
OUTPUT_DIR / "loss_curve.png"
PosixPath('artifacts/custom_model_extension/loss_curve.png')

Evaluate the held-out split#

split = prepared.test or prepared.val or prepared.train
predictions = trainer.predict_dataset(split)
metrics = evaluate_predictions("representation", predictions)
metrics
{'mse': 0.8269846439361572,
 'mae': 0.4079626798629761,
 'pearson': 0.21515628695487976,
 'spearman': 0.10855005603030946,
 'silhouette': 0.22205692529678345,
 'knn_label_consistency': 0.8787878787878788,
 'ari': 0.4655291916088742,
 'nmi': 0.6799320582228012}
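The first two entries are standard elementwise error metrics. A minimal NumPy sketch of how mse and mae are defined (the actual evaluate_predictions implementation may compute them differently, e.g. per cell before averaging):

```python
import numpy as np


def reconstruction_errors(x, x_hat) -> dict:
    diff = np.asarray(x, float) - np.asarray(x_hat, float)
    return {
        "mse": float(np.mean(diff ** 2)),    # mean squared error
        "mae": float(np.mean(np.abs(diff))),  # mean absolute error
    }


print(reconstruction_errors(np.array([0.0, 1.0]), np.array([1.0, 1.0])))
```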

Push the latent space back into Scanpy#

This is the important handoff: scDLKit trains the model, then Scanpy continues the downstream embedding workflow from adata.obsm.

full_split = transform_adata(
    adata,
    prepared.preprocessing,
    label_encoder=prepared.label_encoder,
    batch_encoder=prepared.batch_encoder,
)
full_predictions = trainer.predict_dataset(full_split)
adata.obsm["X_scdlkit_custom"] = full_predictions["latent"]

sc.pp.neighbors(adata, use_rep="X_scdlkit_custom")
sc.tl.umap(adata, random_state=42)

umap_fig = sc.pl.umap(adata, color="louvain", return_fig=True, frameon=False)
umap_fig.savefig(OUTPUT_DIR / "latent_umap.png", dpi=150, bbox_inches="tight")
plt.close(umap_fig)
OUTPUT_DIR / "latent_umap.png"
PosixPath('artifacts/custom_model_extension/latent_umap.png')
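One common failure mode in this handoff is a shape mismatch: anything written to adata.obsm must be 2-D with one row per observation. A small defensive check, written in pure NumPy with a hypothetical helper name:

```python
import numpy as np


def check_embedding(emb, n_obs: int) -> np.ndarray:
    """Validate that an embedding is 2-D with one row per cell."""
    emb = np.asarray(emb)
    if emb.ndim != 2 or emb.shape[0] != n_obs:
        raise ValueError(f"expected (n_obs, d) = ({n_obs}, d); got {emb.shape}")
    return emb


# PBMC3k has 2638 cells; the custom latent space has 16 dimensions.
latent = check_embedding(np.zeros((2638, 16)), n_obs=2638)
print(latent.shape)  # (2638, 16)
```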

Save a simple report#

This uses the public low-level evaluation helpers rather than TaskRunner.

report_md = OUTPUT_DIR / "report.md"
report_csv = OUTPUT_DIR / "report.csv"

save_markdown_report(
    metrics,
    path=report_md,
    title="scDLKit custom model extension report",
    extra_sections=[
        "## Notes",
        "",
        "- Wrapped model type: `CustomAutoencoder`",
        "- Training surface: `Trainer`",
        "- Task: `representation`",
    ],
)
save_metrics_table(metrics, report_csv)

report_md, report_csv
(PosixPath('artifacts/custom_model_extension/report.md'),
 PosixPath('artifacts/custom_model_extension/report.csv'))

Expected outputs#

This tutorial should leave four artifacts under artifacts/custom_model_extension/:

  • report.md

  • report.csv

  • loss_curve.png

  • latent_umap.png

Validation checklist:

  1. The loss curve should decrease.

  2. The latent UMAP should show broad structure rather than collapse.

  3. The saved report should make it easy to compare your wrapped model against the built-in baselines.
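The first checklist item can be verified programmatically from the training history. Using the last five train_loss values shown earlier (in a real run, read them from trainer.history_frame instead of hard-coding them):

```python
# Tail of the train_loss column from the history table above.
train_loss = [0.816378, 0.810021, 0.806898, 0.804627, 0.801235]
strictly_decreasing = all(b < a for a, b in zip(train_loss, train_loss[1:]))
print(strictly_decreasing)  # True
```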

Recommended next tutorials:

  • PBMC model comparison to benchmark against built-in baselines

  • reconstruction sanity check if your custom model also exposes reconstructed expression

output_paths = {
    "report_md": str(report_md),
    "report_csv": str(report_csv),
    "loss_curve_png": str(OUTPUT_DIR / "loss_curve.png"),
    "latent_umap_png": str(OUTPUT_DIR / "latent_umap.png"),
}
output_paths
{'report_md': 'artifacts/custom_model_extension/report.md',
 'report_csv': 'artifacts/custom_model_extension/report.csv',
 'loss_curve_png': 'artifacts/custom_model_extension/loss_curve.png',
 'latent_umap_png': 'artifacts/custom_model_extension/latent_umap.png'}