Experimental scGPT Cell-Type Annotation#

Who this notebook is for#

Researchers who want to compare a classical baseline, frozen scGPT embeddings, head-only fine-tuning, and LoRA fine-tuning on a labeled human single-cell RNA dataset.

Prerequisites#

  • python -m pip install "scdlkit[foundation,tutorials]"

  • enough free disk space for the official whole-human checkpoint cache

  • familiarity with a basic Scanpy workflow

What you will learn#

  • how to prepare labeled PBMC data for scGPT

  • how to compare PCA + logistic regression, frozen scGPT, head-only tuning, and LoRA tuning

  • how to inspect metrics, UMAPs, and a confusion matrix for the best strategy

What is out of scope#

  • full-model scGPT fine-tuning

  • non-human or multimodal workflows

  • treating this as a production annotation pipeline

Outline#

  1. load PBMC data and choose a quickstart or full profile

  2. prepare tokenized scGPT data and deterministic train/validation/test splits

  3. run the classical baseline and three scGPT strategies

  4. compare metrics and inspect the tuned embedding geometry

  5. save tutorial artifacts for the docs and quality pipeline

Next step#

  • easiest public wrapper path: examples/scgpt_dataset_specific_annotation.ipynb

  • API pages: docs/api/foundation.md and docs/api/annotation.md

Published tutorial status#

This page is a static notebook copy published for documentation review. It is meant to show the exact workflow and outputs from the last recorded run.

  • Last run date (UTC): 2026-03-27 09:22 UTC

  • Publication mode: static executed tutorial

  • Execution profile: published

  • Artifact check in this sync: passed

  • Source notebook: examples/scgpt_cell_type_annotation.ipynb

from __future__ import annotations

from pathlib import Path
from time import perf_counter

import numpy as np
import pandas as pd
import scanpy as sc
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression

from scdlkit.evaluation import evaluate_predictions, save_markdown_report, save_metrics_table
from scdlkit.evaluation.metrics import representation_metrics
from scdlkit.foundation import (
    ScGPTLoRAConfig,
    load_scgpt_annotation_model,
    load_scgpt_model,
    prepare_scgpt_data,
    split_scgpt_data,
)
from scdlkit.training import Trainer
from scdlkit.visualization.classification import plot_confusion_matrix

Configuration#

Use the default quickstart profile for a CPU-friendly docs run. This quickstart mirrors the small foundation annotation smoke configuration, so docs and CI exercise the same experimental path. Set PROFILE = 'full' if you want to fine-tune on a larger built-in PBMC subset locally.

PROFILE = 'quickstart'

CONFIGS = {
    'quickstart': {
        'max_cells': 8,
        'max_genes': 32,
        'batch_size': 64,
        'head_epochs': 1,
        'lora_epochs': 1,
        'head_lr': 5e-3,
        'lora_lr': 2e-3,
    },
    'full': {
        'max_cells': 128,
        'max_genes': 192,
        'batch_size': 32,
        'head_epochs': 3,
        'lora_epochs': 2,
        'head_lr': 5e-3,
        'lora_lr': 2e-3,
    },
}

if PROFILE not in CONFIGS:
    raise ValueError(f'Unsupported PROFILE={PROFILE!r}. Expected one of {tuple(CONFIGS)}.')

config = CONFIGS[PROFILE]
output_dir = Path('artifacts/scgpt_cell_type_annotation')
output_dir.mkdir(parents=True, exist_ok=True)
SEED = 42

Load PBMC data#

This tutorial uses scanpy.datasets.pbmc3k_processed() because the goal here is model comparison and adaptation, not raw-count preprocessing. For raw QC and preprocessing context, use the official Scanpy PBMC tutorials and then bring the processed matrix into scDLKit.

The quickstart profile also trims the matched gene set so the experimental fine-tuning path stays responsive in docs and CI.

def stratified_subset_adata(adata, *, label_key: str, max_cells: int | None, seed: int):
    if max_cells is None or adata.n_obs <= max_cells:
        return adata.copy()
    labels = adata.obs[label_key].astype(str).to_numpy()
    indices = np.arange(adata.n_obs)
    rng = np.random.default_rng(seed)
    sampled = []
    for label in np.unique(labels):
        label_indices = indices[labels == label]
        take = max(1, int(round(max_cells * (len(label_indices) / len(indices)))))
        sampled.extend(rng.choice(label_indices, size=min(take, len(label_indices)), replace=False))
    sampled = np.array(sorted({int(index) for index in sampled}), dtype=int)
    if len(sampled) > max_cells:
        sampled = np.sort(rng.choice(sampled, size=max_cells, replace=False))
    return adata[sampled].copy()

adata_full = sc.datasets.pbmc3k_processed()
adata = stratified_subset_adata(
    adata_full,
    label_key='louvain',
    max_cells=config['max_cells'],
    seed=SEED,
)
def subset_high_variance_genes(adata, *, max_genes: int | None):
    if max_genes is None or adata.n_vars <= max_genes:
        return adata.copy()
    source = adata.raw.to_adata() if adata.raw is not None else adata.copy()
    matrix = source.X.toarray() if hasattr(source.X, 'toarray') else np.asarray(source.X)
    variances = np.var(matrix, axis=0)
    keep_indices = np.argsort(variances)[-max_genes:]
    subset = source[:, np.sort(keep_indices)].copy()
    subset.raw = subset.copy()
    return subset

adata = subset_high_variance_genes(adata, max_genes=config['max_genes'])
adata.shape
(8, 32)
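The proportional quota logic in stratified_subset_adata can over-allocate because every label is guaranteed at least one cell, which is why the function ends with a final trim back to max_cells. A numpy-only sketch of the same arithmetic on toy labels (not the PBMC data) makes the over-allocation visible:

```python
import numpy as np

# Toy labels standing in for adata.obs['louvain']; not the real PBMC data.
labels = np.array(['B'] * 6 + ['T'] * 3 + ['NK'] * 1)
max_cells = 5
indices = np.arange(len(labels))
rng = np.random.default_rng(42)

sampled = []
for label in np.unique(labels):
    label_indices = indices[labels == label]
    # Proportional quota, but never fewer than one cell per label.
    take = max(1, int(round(max_cells * len(label_indices) / len(indices))))
    sampled.extend(rng.choice(label_indices, size=min(take, len(label_indices)), replace=False))

sampled = np.array(sorted({int(i) for i in sampled}), dtype=int)
print(len(sampled))  # 6 cells drawn for a 5-cell budget, hence the final trim
```

Here the quotas come out as 3 + 2 + 1 = 6 for a budget of 5, so the trailing `rng.choice(sampled, size=max_cells, ...)` step in the real helper is what enforces the budget.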

Prepare scGPT data and splits#

The scGPT path uses a tokenized preparation pipeline instead of the standard prepare_data(...) matrix path. The split helper preserves label metadata so the confusion matrix can be labeled consistently.

prepared = prepare_scgpt_data(
    adata,
    checkpoint='whole-human',
    label_key='louvain',
    batch_size=config['batch_size'],
    use_raw=True,
    min_gene_overlap=16 if PROFILE == 'quickstart' else 96,
)
split = split_scgpt_data(prepared, val_size=0.15, test_size=0.15, random_state=SEED)
label_categories = list(prepared.label_categories or [])
label_categories
['CD4 T cells',
 'CD8 T cells',
 'Dendritic cells',
 'FCGR3A+ Monocytes',
 'Megakaryocytes',
 'NK cells']
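split_scgpt_data is library code, but the shape of a deterministic 15/15/70 split is easy to sketch. The helper below is an illustrative stand-in, not the library's implementation (the real split also carries label metadata for the confusion matrix):

```python
import numpy as np

def three_way_split(n_obs, *, val_size, test_size, random_state):
    # Illustrative sketch only, not the internals of split_scgpt_data.
    rng = np.random.default_rng(random_state)
    order = rng.permutation(n_obs)
    n_test = max(1, int(round(n_obs * test_size)))
    n_val = max(1, int(round(n_obs * val_size)))
    test_idx = order[:n_test]
    val_idx = order[n_test:n_test + n_val]
    train_idx = order[n_test + n_val:]
    return train_idx, val_idx, test_idx

train_idx, val_idx, test_idx = three_way_split(8, val_size=0.15, test_size=0.15, random_state=42)
print(len(train_idx), len(val_idx), len(test_idx))  # 6 1 1
```

With only 8 quickstart cells, validation and test each receive a single cell, which is worth keeping in mind when reading the metrics later in this page.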

Helpers#

These helpers keep the notebook readable while preserving the exact comparison story: classical baseline, frozen scGPT, head-only tuning, and LoRA tuning.

def count_trainable_parameters(model) -> int:
    return int(
        sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)
    )


def subset_adata_from_dataset(adata, dataset):
    from torch.utils.data import Subset

    if isinstance(dataset, Subset):
        return adata[np.asarray(dataset.indices, dtype=int)].copy()
    return adata.copy()


def dataset_indices(dataset, total_obs: int):
    from torch.utils.data import Subset

    if isinstance(dataset, Subset):
        return np.asarray(dataset.indices, dtype=int)
    return np.arange(total_obs, dtype=int)


def expand_probabilities(probabilities, classes, *, num_classes: int):
    expanded = np.zeros((probabilities.shape[0], num_classes), dtype='float32')
    expanded[:, np.asarray(classes, dtype=int)] = np.asarray(probabilities, dtype='float32')
    return expanded


def save_umap(adata_subset, latent, *, output_path: Path, seed: int):
    plot_adata = adata_subset.copy()
    latent_array = np.asarray(latent, dtype='float32')
    plot_adata.obsm['X_scgpt_annotation'] = latent_array
    n_obs = int(plot_adata.n_obs)

    def latent_fallback_figure():
        # Scatter the first two latent dimensions when a UMAP is unavailable.
        figure, axis = plt.subplots(figsize=(5, 4))
        x_values = latent_array[:, 0]
        y_values = (
            latent_array[:, 1]
            if latent_array.shape[1] > 1
            else np.zeros(n_obs, dtype='float32')
        )
        labels = plot_adata.obs['louvain'].astype(str).to_numpy()
        for label in sorted(set(labels)):
            mask = labels == label
            axis.scatter(x_values[mask], y_values[mask], label=label, s=48, alpha=0.9)
        axis.set_title('Latent view (UMAP fallback)')
        axis.set_xlabel('latent_1')
        axis.set_ylabel('latent_2')
        axis.legend(loc='best', fontsize=8, frameon=False)
        return figure

    if n_obs < 4:
        figure = latent_fallback_figure()
    else:
        n_neighbors = max(2, min(5, n_obs - 1))
        sc.pp.neighbors(plot_adata, use_rep='X_scgpt_annotation', n_neighbors=n_neighbors)
        try:
            sc.tl.umap(plot_adata, random_state=seed, init_pos='random')
            figure = sc.pl.umap(plot_adata, color='louvain', return_fig=True, frameon=False)
        except TypeError:
            figure = latent_fallback_figure()
    figure.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(figure)


def run_pca_logistic_baseline(adata, split, *, label_key: str, seed: int):
    train_indices = dataset_indices(split.train, adata.n_obs)
    test_dataset = split.test or split.val or split.train
    test_indices = dataset_indices(test_dataset, adata.n_obs)
    labels = pd.Categorical(adata.obs[label_key].astype(str))
    x_train = (
        adata.X[train_indices].toarray()
        if hasattr(adata.X[train_indices], 'toarray')
        else np.asarray(adata.X[train_indices], dtype='float32')
    )
    x_test = (
        adata.X[test_indices].toarray()
        if hasattr(adata.X[test_indices], 'toarray')
        else np.asarray(adata.X[test_indices], dtype='float32')
    )
    y_train = labels.codes[train_indices]
    y_test = labels.codes[test_indices]
    n_components = min(32, x_train.shape[0] - 1, x_train.shape[1])
    started_at = perf_counter()
    pca = PCA(n_components=n_components, random_state=seed)
    train_latent = pca.fit_transform(x_train)
    test_latent = pca.transform(x_test)
    classifier = LogisticRegression(max_iter=1000, random_state=seed)
    classifier.fit(train_latent, y_train)
    logits = expand_probabilities(
        classifier.predict_proba(test_latent),
        classifier.classes_,
        num_classes=len(labels.categories),
    )
    runtime_sec = perf_counter() - started_at
    predictions = {'logits': logits, 'y': y_test, 'latent': test_latent}
    metrics = evaluate_predictions('classification', predictions)
    metrics.update(representation_metrics(test_latent, y_test, None))
    return {
        'strategy': 'pca_logistic_annotation',
        'metrics': metrics,
        'latent': test_latent,
        'runtime_sec': runtime_sec,
        'trainable_parameters': 0,
        'class_names': [str(value) for value in labels.categories],
        'dataset': test_dataset,
    }


def run_frozen_probe(adata, prepared, split, *, seed: int):
    model = load_scgpt_model('whole-human', device='auto')
    trainer = Trainer(
        model=model,
        task='representation',
        batch_size=prepared.batch_size,
        device='auto',
        epochs=1,
    )
    started_at = perf_counter()
    train_predictions = trainer.predict_dataset(split.train)
    test_dataset = split.test or split.val or split.train
    test_predictions = trainer.predict_dataset(test_dataset)
    classifier = LogisticRegression(max_iter=1000, random_state=seed)
    classifier.fit(train_predictions['latent'], train_predictions['y'])
    logits = expand_probabilities(
        classifier.predict_proba(test_predictions['latent']),
        classifier.classes_,
        num_classes=len(label_categories),
    )
    runtime_sec = perf_counter() - started_at
    predictions = {
        'logits': logits,
        'y': test_predictions['y'],
        'latent': test_predictions['latent'],
    }
    metrics = evaluate_predictions('classification', predictions)
    metrics.update(representation_metrics(test_predictions['latent'], test_predictions['y'], None))
    return {
        'strategy': 'scgpt_frozen_probe',
        'metrics': metrics,
        'latent': test_predictions['latent'],
        'runtime_sec': runtime_sec,
        'trainable_parameters': 0,
        'class_names': label_categories,
        'dataset': test_dataset,
    }


def run_tuned_strategy(prepared, split, *, tuning_strategy: str, epochs: int, lr: float):
    model = load_scgpt_annotation_model(
        num_classes=len(label_categories),
        checkpoint='whole-human',
        tuning_strategy=tuning_strategy,
        label_categories=tuple(label_categories),
        device='auto',
        lora_config=(
            ScGPTLoRAConfig(rank=4, alpha=8.0, dropout=0.05, target_modules=('linear1', 'linear2'))
            if tuning_strategy == 'lora'
            else None
        ),
    )
    trainer = Trainer(
        model=model,
        task='classification',
        batch_size=prepared.batch_size,
        epochs=epochs,
        lr=lr,
        device='auto',
        early_stopping_patience=3,
        seed=SEED,
    )
    test_dataset = split.test or split.val or split.train
    started_at = perf_counter()
    trainer.fit(split.train, split.val)
    predictions = trainer.predict_dataset(test_dataset)
    runtime_sec = perf_counter() - started_at
    metrics = evaluate_predictions('classification', predictions)
    metrics.update(representation_metrics(predictions['latent'], predictions['y'], None))
    return {
        'strategy': f'scgpt_{tuning_strategy}',
        'metrics': metrics,
        'latent': predictions['latent'],
        'runtime_sec': runtime_sec,
        'trainable_parameters': count_trainable_parameters(model),
        'class_names': label_categories,
        'dataset': test_dataset,
    }
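One subtlety in the probe paths is worth seeing in isolation: when the tiny training split misses some classes, LogisticRegression only emits probability columns for the classes it actually saw, and expand_probabilities pads them back into the full label space. A standalone repetition of that helper on toy inputs:

```python
import numpy as np

def expand_probabilities(probabilities, classes, *, num_classes: int):
    # Place the classifier's columns at the positions of the classes it saw.
    expanded = np.zeros((probabilities.shape[0], num_classes), dtype='float32')
    expanded[:, np.asarray(classes, dtype=int)] = np.asarray(probabilities, dtype='float32')
    return expanded

# A classifier that only saw classes 0 and 3 out of the 6 PBMC labels.
probs = np.array([[0.9, 0.1], [0.2, 0.8]])
full = expand_probabilities(probs, classes=[0, 3], num_classes=6)
print(full.shape)  # (2, 6); columns 1, 2, 4, 5 stay zero
```

Without this padding, evaluate_predictions would see logits whose width disagrees with the label vocabulary whenever the train split is missing classes.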

Run the comparison#

Interpretation target:

  • frozen probe: do I need fine-tuning at all?

  • head-only tuning: what is the cheapest trainable step up?

  • LoRA tuning: does parameter-efficient adaptation preserve performance without fully tuning the backbone?

pca_result = run_pca_logistic_baseline(adata, split, label_key='louvain', seed=SEED)
frozen_result = run_frozen_probe(adata, prepared, split, seed=SEED)
head_result = run_tuned_strategy(
    prepared,
    split,
    tuning_strategy='head',
    epochs=config['head_epochs'],
    lr=config['head_lr'],
)
lora_result = run_tuned_strategy(
    prepared,
    split,
    tuning_strategy='lora',
    epochs=config['lora_epochs'],
    lr=config['lora_lr'],
)

results = [pca_result, frozen_result, head_result, lora_result]
/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/sklearn/metrics/_classification.py:2924: UserWarning: y_pred contains classes not in y_true
  warnings.warn("y_pred contains classes not in y_true")
/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/sklearn/metrics/_classification.py:2924: UserWarning: y_pred contains classes not in y_true
  warnings.warn("y_pred contains classes not in y_true")
/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/sklearn/metrics/_classification.py:2924: UserWarning: y_pred contains classes not in y_true
  warnings.warn("y_pred contains classes not in y_true")
/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/sklearn/metrics/_classification.py:2924: UserWarning: y_pred contains classes not in y_true
  warnings.warn("y_pred contains classes not in y_true")
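Those warnings are expected on the quickstart split: with so few test cells, the classifier can predict a class that never appears in y_true. A minimal reproduction of the underlying condition (hypothetical labels; the exact warning text depends on your scikit-learn version):

```python
import numpy as np
from sklearn.metrics import f1_score

y_true = np.array([0, 0])  # the tiny test split contains only one class
y_pred = np.array([0, 3])  # the classifier predicts a class absent from y_true

# Class 3 has no true examples, so its F1 is 0 and sklearn warns.
macro = f1_score(y_true, y_pred, average='macro')
print(round(macro, 6))  # 0.333333
```

The metrics are still computed; the warning is just flagging that some per-class scores are defined to be zero, which drags macro F1 down on degenerate splits.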

Compare the strategies#

What correctness looks like in this tutorial:

  • the frozen probe should provide a meaningful baseline rather than a random classifier

  • head-only tuning should usually improve over the frozen probe on this labeled PBMC task

  • LoRA should be competitive with head-only tuning while keeping the trainable parameter count small

This notebook is still experimental. Treat the metric table as a reproducible comparison scaffold, not as a universal ranking for all single-cell datasets. In particular, the quickstart profile runs on only eight cells, so the published metrics are degenerate by design; rerun with the full profile for a more informative comparison.

strategy_rows = []
for result in results:
    strategy_rows.append(
        {
            'strategy': result['strategy'],
            'accuracy': float(result['metrics']['accuracy']),
            'macro_f1': float(result['metrics']['macro_f1']),
            'silhouette': float(result['metrics'].get('silhouette', 0.0)),
            'runtime_sec': float(result['runtime_sec']),
            'trainable_parameters': int(result['trainable_parameters']),
        }
    )

strategy_frame = pd.DataFrame(strategy_rows).sort_values(['macro_f1', 'accuracy'], ascending=False)
strategy_frame.to_csv(output_dir / 'strategy_metrics.csv', index=False)
strategy_frame
                  strategy  accuracy  macro_f1  silhouette  runtime_sec  trainable_parameters
0  pca_logistic_annotation       0.5  0.333333         0.0     0.006513                     0
1       scgpt_frozen_probe       0.5  0.333333         0.0     0.057630                     0
2               scgpt_head       0.5  0.333333         0.0     1.224818                  4102
3               scgpt_lora       0.0  0.000000         0.0     0.346037                102406

Save UMAPs and the best confusion matrix#

The two embedding figures are fixed by the release plan:

  • frozen scGPT embedding UMAP

  • LoRA embedding UMAP

The confusion matrix is saved for the best-performing strategy by macro F1.

frozen_test_adata = subset_adata_from_dataset(adata, frozen_result['dataset'])
lora_test_adata = subset_adata_from_dataset(adata, lora_result['dataset'])

frozen_umap_path = output_dir / 'frozen_embedding_umap.png'
lora_umap_path = output_dir / 'lora_embedding_umap.png'
save_umap(frozen_test_adata, frozen_result['latent'], output_path=frozen_umap_path, seed=SEED)
save_umap(lora_test_adata, lora_result['latent'], output_path=lora_umap_path, seed=SEED)

best_result = max(
    results,
    key=lambda item: (
        float(item['metrics']['macro_f1']),
        float(item['metrics']['accuracy']),
    ),
)
best_confusion_path = output_dir / 'best_strategy_confusion_matrix.png'
confusion_figure, _ = plot_confusion_matrix(
    best_result['metrics']['confusion_matrix'],
    class_names=best_result['class_names'],
)
confusion_figure.savefig(best_confusion_path, dpi=150, bbox_inches='tight')
plt.close(confusion_figure)

best_result['strategy'], best_confusion_path
('pca_logistic_annotation',
 PosixPath('artifacts/scgpt_cell_type_annotation/best_strategy_confusion_matrix.png'))
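The winner above is partly a tie-breaking artifact: in the quickstart run, three strategies share the same macro F1 and accuracy, and Python's max returns the first maximal element it encounters, so the classical baseline comes out on top simply because it appears first in results. A minimal reproduction with scores mirroring the table above:

```python
# Hypothetical (strategy, macro_f1, accuracy) tuples mirroring the quickstart table.
scores = [
    ('pca_logistic_annotation', 0.333333, 0.5),
    ('scgpt_frozen_probe', 0.333333, 0.5),
    ('scgpt_head', 0.333333, 0.5),
    ('scgpt_lora', 0.0, 0.0),
]
# max keeps the first item when keys tie, matching the order of `results`.
best = max(scores, key=lambda item: (item[1], item[2]))
print(best[0])  # pca_logistic_annotation
```

On a real full-profile run the scores should separate and the tie-breaking order stops mattering.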

How to interpret the output#

  • If frozen scGPT is already strong, your dataset may only need a linear probe or a very small amount of adaptation.

  • If head-only tuning wins clearly, the dataset benefits from supervised adaptation even when the backbone stays frozen.

  • If LoRA matches or beats head-only tuning with a small trainable parameter count, that is a strong signal for parameter-efficient fine-tuning.

  • If all strategies cluster poorly or the confusion matrix is chaotic, inspect gene overlap, label quality, and whether the tutorial subset is too small for your local run.

best_metrics = {
    'profile': PROFILE,
    'subset_cells': int(adata.n_obs),
    'checkpoint': 'whole-human',
    'best_strategy': best_result['strategy'],
    'best_accuracy': float(best_result['metrics']['accuracy']),
    'best_macro_f1': float(best_result['metrics']['macro_f1']),
    'best_silhouette': float(best_result['metrics'].get('silhouette', 0.0)),
    'num_genes_matched': int(prepared.num_genes_matched),
}

save_metrics_table(best_metrics, output_dir / 'report.csv')
report_path = save_markdown_report(
    best_metrics,
    path=output_dir / 'report.md',
    title='Experimental scGPT cell-type annotation summary',
    extra_sections=[
        '',
        '## Interpretation notes',
        '',
        '- Built-in PBMC is the tutorial dataset, not the limit of the public API.',
        '- Frozen probe answers the baseline question: do I need fine-tuning at all?',
        '- Head-only tuning is the cheapest trainable path.',
        '- LoRA is the first parameter-efficient adaptation path in scDLKit.',
        '- Full backbone fine-tuning is intentionally deferred in this release line.',
    ],
)
report_path
PosixPath('artifacts/scgpt_cell_type_annotation/report.md')

Next tutorial#

output_paths = {
    'report_md': str(output_dir / 'report.md'),
    'report_csv': str(output_dir / 'report.csv'),
    'strategy_metrics_csv': str(output_dir / 'strategy_metrics.csv'),
    'frozen_embedding_umap': str(output_dir / 'frozen_embedding_umap.png'),
    'lora_embedding_umap': str(output_dir / 'lora_embedding_umap.png'),
    'best_strategy_confusion_matrix': str(output_dir / 'best_strategy_confusion_matrix.png'),
}
output_paths
{'report_md': 'artifacts/scgpt_cell_type_annotation/report.md',
 'report_csv': 'artifacts/scgpt_cell_type_annotation/report.csv',
 'strategy_metrics_csv': 'artifacts/scgpt_cell_type_annotation/strategy_metrics.csv',
 'frozen_embedding_umap': 'artifacts/scgpt_cell_type_annotation/frozen_embedding_umap.png',
 'lora_embedding_umap': 'artifacts/scgpt_cell_type_annotation/lora_embedding_umap.png',
 'best_strategy_confusion_matrix': 'artifacts/scgpt_cell_type_annotation/best_strategy_confusion_matrix.png'}