Experimental scGPT Cell-Type Annotation#
Who this notebook is for#
Researchers who want to compare a classical baseline, frozen scGPT embeddings, head-only fine-tuning, and LoRA fine-tuning on a labeled human single-cell RNA dataset.
Prerequisites#
python -m pip install "scdlkit[foundation,tutorials]"
enough free disk space for the official whole-human checkpoint cache
familiarity with a basic Scanpy workflow
What you will learn#
how to prepare labeled PBMC data for scGPT
how to compare PCA + logistic regression, frozen scGPT, head-only tuning, and LoRA tuning
how to inspect metrics, UMAPs, and a confusion matrix for the best strategy
What is out of scope#
full-model scGPT fine-tuning
non-human or multimodal workflows
treating this as a production annotation pipeline
Outline#
load PBMC data and choose a quickstart or full profile
prepare tokenized scGPT data and deterministic train/validation/test splits
run the classical baseline and three scGPT strategies
compare metrics and inspect the tuned embedding geometry
save tutorial artifacts for the docs and quality pipeline
Next step#
easiest public wrapper path: examples/scgpt_dataset_specific_annotation.ipynb
API pages: docs/api/foundation.md and docs/api/annotation.md
Published tutorial status
This page is a static notebook copy published for documentation review. It is meant to show the exact workflow and outputs from the last recorded run.
Last run date (UTC): 2026-03-27 09:22 UTC
Publication mode: static executed tutorial
Execution profile: published
Artifact check in this sync: passed
Source notebook:
examples/scgpt_cell_type_annotation.ipynb
from __future__ import annotations
from pathlib import Path
from time import perf_counter
import numpy as np
import pandas as pd
import scanpy as sc
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from scdlkit.evaluation import evaluate_predictions, save_markdown_report, save_metrics_table
from scdlkit.evaluation.metrics import representation_metrics
from scdlkit.foundation import (
ScGPTLoRAConfig,
load_scgpt_annotation_model,
load_scgpt_model,
prepare_scgpt_data,
split_scgpt_data,
)
from scdlkit.training import Trainer
from scdlkit.visualization.classification import plot_confusion_matrix
Configuration#
Use the default quickstart profile for a CPU-friendly docs run. This quickstart mirrors the small foundation annotation smoke configuration, so the docs and CI exercise the same experimental path. Switch to the full profile if you want to fine-tune on a larger built-in PBMC subset locally.
PROFILE = 'quickstart'
CONFIGS = {
'quickstart': {
'max_cells': 8,
'max_genes': 32,
'batch_size': 64,
'head_epochs': 1,
'lora_epochs': 1,
'head_lr': 5e-3,
'lora_lr': 2e-3,
},
'full': {
'max_cells': 128,
'max_genes': 192,
'batch_size': 32,
'head_epochs': 3,
'lora_epochs': 2,
'head_lr': 5e-3,
'lora_lr': 2e-3,
},
}
if PROFILE not in CONFIGS:
raise ValueError(f'Unsupported PROFILE={PROFILE!r}. Expected one of {tuple(CONFIGS)}.')
config = CONFIGS[PROFILE]
output_dir = Path('artifacts/scgpt_cell_type_annotation')
output_dir.mkdir(parents=True, exist_ok=True)
SEED = 42
Load PBMC data#
This tutorial uses scanpy.datasets.pbmc3k_processed() because the goal here is model comparison and adaptation, not raw-count preprocessing. For raw QC and preprocessing context, use the official Scanpy PBMC tutorials and then bring the processed matrix into scDLKit.
The quickstart profile also trims the matched gene set so the experimental fine-tuning path stays responsive in docs and CI.
def stratified_subset_adata(adata, *, label_key: str, max_cells: int | None, seed: int):
    """Sample up to ``max_cells`` cells while roughly preserving label proportions."""
    if max_cells is None or adata.n_obs <= max_cells:
        return adata.copy()
    labels = adata.obs[label_key].astype(str).to_numpy()
    indices = np.arange(adata.n_obs)
    rng = np.random.default_rng(seed)
    sampled = []
    for label in np.unique(labels):
        label_indices = indices[labels == label]
        # Proportional allocation, but keep at least one cell per label.
        take = max(1, int(round(max_cells * (len(label_indices) / len(indices)))))
        sampled.extend(rng.choice(label_indices, size=min(take, len(label_indices)), replace=False))
    sampled = np.array(sorted({int(index) for index in sampled}), dtype=int)
    if len(sampled) > max_cells:
        # Rounding can overshoot the budget; trim back down to max_cells.
        sampled = np.sort(rng.choice(sampled, size=max_cells, replace=False))
    return adata[sampled].copy()
adata_full = sc.datasets.pbmc3k_processed()
adata = stratified_subset_adata(
adata_full,
label_key='louvain',
max_cells=config['max_cells'],
seed=SEED,
)
def subset_high_variance_genes(adata, *, max_genes: int | None):
if max_genes is None or adata.n_vars <= max_genes:
return adata.copy()
source = adata.raw.to_adata() if adata.raw is not None else adata.copy()
matrix = source.X.toarray() if hasattr(source.X, 'toarray') else np.asarray(source.X)
variances = np.var(matrix, axis=0)
keep_indices = np.argsort(variances)[-max_genes:]
subset = source[:, np.sort(keep_indices)].copy()
subset.raw = subset.copy()
return subset
adata = subset_high_variance_genes(adata, max_genes=config['max_genes'])
adata.shape
(8, 32)
Prepare scGPT data and splits#
The scGPT path uses a tokenized preparation pipeline instead of the standard prepare_data(...) matrix path. The split helper preserves label metadata so the confusion matrix can be labeled consistently.
prepared = prepare_scgpt_data(
adata,
checkpoint='whole-human',
label_key='louvain',
batch_size=config['batch_size'],
use_raw=True,
min_gene_overlap=16 if PROFILE == 'quickstart' else 96,
)
split = split_scgpt_data(prepared, val_size=0.15, test_size=0.15, random_state=SEED)
label_categories = list(prepared.label_categories or [])
label_categories
['CD4 T cells',
'CD8 T cells',
'Dendritic cells',
'FCGR3A+ Monocytes',
'Megakaryocytes',
'NK cells']
Helpers#
These helpers keep the notebook readable while preserving the exact comparison story: classical baseline, frozen scGPT, head-only tuning, and LoRA tuning.
def count_trainable_parameters(model) -> int:
return int(
sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)
)
def subset_adata_from_dataset(adata, dataset):
from torch.utils.data import Subset
if isinstance(dataset, Subset):
return adata[np.asarray(dataset.indices, dtype=int)].copy()
return adata.copy()
def dataset_indices(dataset, total_obs: int):
from torch.utils.data import Subset
if isinstance(dataset, Subset):
return np.asarray(dataset.indices, dtype=int)
return np.arange(total_obs, dtype=int)
def expand_probabilities(probabilities, classes, *, num_classes: int):
expanded = np.zeros((probabilities.shape[0], num_classes), dtype='float32')
expanded[:, np.asarray(classes, dtype=int)] = np.asarray(probabilities, dtype='float32')
return expanded
def save_umap(adata_subset, latent, *, output_path: Path, seed: int):
    plot_adata = adata_subset.copy()
    latent_array = np.asarray(latent, dtype='float32')
    plot_adata.obsm['X_scgpt_annotation'] = latent_array
    n_obs = int(plot_adata.n_obs)

    def scatter_fallback():
        # Plain scatter of the first two latent dimensions, used when the
        # subset is too small for a neighbor graph or UMAP cannot run.
        figure, axis = plt.subplots(figsize=(5, 4))
        x_values = latent_array[:, 0]
        y_values = (
            latent_array[:, 1]
            if latent_array.shape[1] > 1
            else np.zeros(n_obs, dtype='float32')
        )
        labels = plot_adata.obs['louvain'].astype(str).to_numpy()
        for label in sorted(set(labels)):
            mask = labels == label
            axis.scatter(x_values[mask], y_values[mask], label=label, s=48, alpha=0.9)
        axis.set_title('Latent view (UMAP fallback)')
        axis.set_xlabel('latent_1')
        axis.set_ylabel('latent_2')
        axis.legend(loc='best', fontsize=8, frameon=False)
        return figure

    if n_obs < 4:
        figure = scatter_fallback()
        figure.savefig(output_path, dpi=150, bbox_inches='tight')
        plt.close(figure)
        return
    n_neighbors = max(2, min(5, n_obs - 1))
    sc.pp.neighbors(plot_adata, use_rep='X_scgpt_annotation', n_neighbors=n_neighbors)
    try:
        sc.tl.umap(plot_adata, random_state=seed, init_pos='random')
        figure = sc.pl.umap(plot_adata, color='louvain', return_fig=True, frameon=False)
    except TypeError:
        figure = scatter_fallback()
    figure.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(figure)
def run_pca_logistic_baseline(adata, split, *, label_key: str, seed: int):
train_indices = dataset_indices(split.train, adata.n_obs)
test_dataset = split.test or split.val or split.train
test_indices = dataset_indices(test_dataset, adata.n_obs)
labels = pd.Categorical(adata.obs[label_key].astype(str))
x_train = (
adata.X[train_indices].toarray()
if hasattr(adata.X[train_indices], 'toarray')
else np.asarray(adata.X[train_indices], dtype='float32')
)
x_test = (
adata.X[test_indices].toarray()
if hasattr(adata.X[test_indices], 'toarray')
else np.asarray(adata.X[test_indices], dtype='float32')
)
y_train = labels.codes[train_indices]
y_test = labels.codes[test_indices]
n_components = min(32, x_train.shape[0] - 1, x_train.shape[1])
started_at = perf_counter()
pca = PCA(n_components=n_components, random_state=seed)
train_latent = pca.fit_transform(x_train)
test_latent = pca.transform(x_test)
classifier = LogisticRegression(max_iter=1000, random_state=seed)
classifier.fit(train_latent, y_train)
logits = expand_probabilities(
classifier.predict_proba(test_latent),
classifier.classes_,
num_classes=len(labels.categories),
)
runtime_sec = perf_counter() - started_at
predictions = {'logits': logits, 'y': y_test, 'latent': test_latent}
metrics = evaluate_predictions('classification', predictions)
metrics.update(representation_metrics(test_latent, y_test, None))
return {
'strategy': 'pca_logistic_annotation',
'metrics': metrics,
'latent': test_latent,
'runtime_sec': runtime_sec,
'trainable_parameters': 0,
'class_names': [str(value) for value in labels.categories],
'dataset': test_dataset,
}
def run_frozen_probe(adata, prepared, split, *, seed: int):
model = load_scgpt_model('whole-human', device='auto')
trainer = Trainer(
model=model,
task='representation',
batch_size=prepared.batch_size,
device='auto',
epochs=1,
)
started_at = perf_counter()
train_predictions = trainer.predict_dataset(split.train)
test_dataset = split.test or split.val or split.train
test_predictions = trainer.predict_dataset(test_dataset)
classifier = LogisticRegression(max_iter=1000, random_state=seed)
classifier.fit(train_predictions['latent'], train_predictions['y'])
logits = expand_probabilities(
classifier.predict_proba(test_predictions['latent']),
classifier.classes_,
num_classes=len(label_categories),
)
runtime_sec = perf_counter() - started_at
predictions = {
'logits': logits,
'y': test_predictions['y'],
'latent': test_predictions['latent'],
}
metrics = evaluate_predictions('classification', predictions)
metrics.update(representation_metrics(test_predictions['latent'], test_predictions['y'], None))
return {
'strategy': 'scgpt_frozen_probe',
'metrics': metrics,
'latent': test_predictions['latent'],
'runtime_sec': runtime_sec,
'trainable_parameters': 0,
'class_names': label_categories,
'dataset': test_dataset,
}
def run_tuned_strategy(prepared, split, *, tuning_strategy: str, epochs: int, lr: float):
model = load_scgpt_annotation_model(
num_classes=len(label_categories),
checkpoint='whole-human',
tuning_strategy=tuning_strategy,
label_categories=tuple(label_categories),
device='auto',
lora_config=(
ScGPTLoRAConfig(rank=4, alpha=8.0, dropout=0.05, target_modules=('linear1', 'linear2'))
if tuning_strategy == 'lora'
else None
),
)
trainer = Trainer(
model=model,
task='classification',
batch_size=prepared.batch_size,
epochs=epochs,
lr=lr,
device='auto',
early_stopping_patience=3,
seed=SEED,
)
test_dataset = split.test or split.val or split.train
started_at = perf_counter()
trainer.fit(split.train, split.val)
predictions = trainer.predict_dataset(test_dataset)
runtime_sec = perf_counter() - started_at
metrics = evaluate_predictions('classification', predictions)
metrics.update(representation_metrics(predictions['latent'], predictions['y'], None))
return {
'strategy': f'scgpt_{tuning_strategy}',
'metrics': metrics,
'latent': predictions['latent'],
'runtime_sec': runtime_sec,
'trainable_parameters': count_trainable_parameters(model),
'class_names': label_categories,
'dataset': test_dataset,
}
Run the comparison#
Interpretation target:
frozen probe: do I need fine-tuning at all?
head-only tuning: what is the cheapest trainable step up?
LoRA tuning: does parameter-efficient adaptation preserve performance without fully tuning the backbone?
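The three questions above form a cost ladder: only pay for more trainable parameters when they buy real macro-F1 gains. As an illustrative sketch (not a scDLKit API), one way to make that decision rule concrete is a tiebreak helper that picks the cheapest strategy within a macro-F1 tolerance of the best; the rows below are hypothetical numbers, not outputs of this run.

```python
def pick_strategy(rows: list[dict], f1_tolerance: float = 0.02) -> dict:
    # Prefer the cheapest strategy (fewest trainable parameters) whose
    # macro F1 is within `f1_tolerance` of the best observed macro F1.
    best_f1 = max(row['macro_f1'] for row in rows)
    candidates = [row for row in rows if best_f1 - row['macro_f1'] <= f1_tolerance]
    return min(candidates, key=lambda row: row['trainable_parameters'])

# Hypothetical results for illustration only.
rows = [
    {'strategy': 'scgpt_frozen_probe', 'macro_f1': 0.81, 'trainable_parameters': 0},
    {'strategy': 'scgpt_head', 'macro_f1': 0.84, 'trainable_parameters': 4102},
    {'strategy': 'scgpt_lora', 'macro_f1': 0.85, 'trainable_parameters': 102406},
]
pick_strategy(rows)['strategy']  # 'scgpt_head': within 0.02 of the best, far cheaper
```

With `f1_tolerance=0.0` the helper degenerates to "best macro F1 wins", which is the rule the confusion-matrix cell below uses.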
pca_result = run_pca_logistic_baseline(adata, split, label_key='louvain', seed=SEED)
frozen_result = run_frozen_probe(adata, prepared, split, seed=SEED)
head_result = run_tuned_strategy(
prepared,
split,
tuning_strategy='head',
epochs=config['head_epochs'],
lr=config['head_lr'],
)
lora_result = run_tuned_strategy(
prepared,
split,
tuning_strategy='lora',
epochs=config['lora_epochs'],
lr=config['lora_lr'],
)
results = [pca_result, frozen_result, head_result, lora_result]
/opt/hostedtoolcache/Python/3.11.15/x64/lib/python3.11/site-packages/sklearn/metrics/_classification.py:2924: UserWarning: y_pred contains classes not in y_true
warnings.warn("y_pred contains classes not in y_true")
Compare the strategies#
What correctness looks like in this tutorial:
the frozen probe should provide a meaningful baseline rather than a random classifier
head-only tuning should usually improve over the frozen probe on this labeled PBMC task
LoRA should be competitive with head-only tuning while keeping the trainable parameter count small
This notebook is still experimental. Treat the metric table as a reproducible comparison scaffold, not as a universal ranking for all single-cell datasets.
strategy_rows = []
for result in results:
strategy_rows.append(
{
'strategy': result['strategy'],
'accuracy': float(result['metrics']['accuracy']),
'macro_f1': float(result['metrics']['macro_f1']),
'silhouette': float(result['metrics'].get('silhouette', 0.0)),
'runtime_sec': float(result['runtime_sec']),
'trainable_parameters': int(result['trainable_parameters']),
}
)
strategy_frame = pd.DataFrame(strategy_rows).sort_values(['macro_f1', 'accuracy'], ascending=False)
strategy_frame.to_csv(output_dir / 'strategy_metrics.csv', index=False)
strategy_frame
| | strategy | accuracy | macro_f1 | silhouette | runtime_sec | trainable_parameters |
|---|---|---|---|---|---|---|
| 0 | pca_logistic_annotation | 0.5 | 0.333333 | 0.0 | 0.006513 | 0 |
| 1 | scgpt_frozen_probe | 0.5 | 0.333333 | 0.0 | 0.057630 | 0 |
| 2 | scgpt_head | 0.5 | 0.333333 | 0.0 | 1.224818 | 4102 |
| 3 | scgpt_lora | 0.0 | 0.000000 | 0.0 | 0.346037 | 102406 |
Save UMAPs and the best confusion matrix#
The two embedding figures are fixed by the release plan:
frozen scGPT embedding UMAP
LoRA embedding UMAP
The confusion matrix is saved for the best-performing strategy by macro F1.
frozen_test_adata = subset_adata_from_dataset(adata, frozen_result['dataset'])
lora_test_adata = subset_adata_from_dataset(adata, lora_result['dataset'])
frozen_umap_path = output_dir / 'frozen_embedding_umap.png'
lora_umap_path = output_dir / 'lora_embedding_umap.png'
save_umap(frozen_test_adata, frozen_result['latent'], output_path=frozen_umap_path, seed=SEED)
save_umap(lora_test_adata, lora_result['latent'], output_path=lora_umap_path, seed=SEED)
best_result = max(
results,
key=lambda item: (
float(item['metrics']['macro_f1']),
float(item['metrics']['accuracy']),
),
)
best_confusion_path = output_dir / 'best_strategy_confusion_matrix.png'
confusion_figure, _ = plot_confusion_matrix(
best_result['metrics']['confusion_matrix'],
class_names=best_result['class_names'],
)
confusion_figure.savefig(best_confusion_path, dpi=150, bbox_inches='tight')
plt.close(confusion_figure)
best_result['strategy'], best_confusion_path
('pca_logistic_annotation',
PosixPath('artifacts/scgpt_cell_type_annotation/best_strategy_confusion_matrix.png'))
How to interpret the output#
If frozen scGPT is already strong, your dataset may only need a linear probe or a very small amount of adaptation.
If head-only tuning wins clearly, the dataset benefits from supervised adaptation even when the backbone stays frozen.
If LoRA matches or beats head-only tuning with a small trainable parameter count, that is a strong signal for parameter-efficient fine-tuning.
If all strategies cluster poorly or the confusion matrix is chaotic, inspect gene overlap, label quality, and whether the tutorial subset is too small for your local run.
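When results look poor across the board, the two checks worth running first are gene overlap and label balance. The helper below is an illustrative sketch, not part of scDLKit; the toy inputs are hypothetical, and in this notebook you would pass `prepared.num_genes_matched`, `adata.n_vars`, and `adata.obs['louvain']` instead.

```python
import pandas as pd

def overlap_and_balance_report(num_genes_matched: int, num_genes_total: int,
                               labels: pd.Series) -> dict:
    """Summarize two common failure modes: low gene overlap and class imbalance."""
    counts = labels.astype(str).value_counts()
    return {
        'gene_overlap_fraction': num_genes_matched / max(1, num_genes_total),
        'smallest_class_size': int(counts.min()),
        'largest_class_size': int(counts.max()),
    }

# In this notebook the call would look like:
#   overlap_and_balance_report(prepared.num_genes_matched, adata.n_vars, adata.obs['louvain'])
# Toy values for illustration:
report = overlap_and_balance_report(24, 32, pd.Series(['B', 'B', 'T', 'T', 'T', 'NK']))
report['gene_overlap_fraction']  # 0.75
```

A low overlap fraction means few of your genes matched the checkpoint vocabulary, and a smallest class of one or two cells means the split cannot produce a meaningful per-class F1 for that label.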
best_metrics = {
'profile': PROFILE,
'subset_cells': int(adata.n_obs),
'checkpoint': 'whole-human',
'best_strategy': best_result['strategy'],
'best_accuracy': float(best_result['metrics']['accuracy']),
'best_macro_f1': float(best_result['metrics']['macro_f1']),
'best_silhouette': float(best_result['metrics'].get('silhouette', 0.0)),
'num_genes_matched': int(prepared.num_genes_matched),
}
save_metrics_table(best_metrics, output_dir / 'report.csv')
report_path = save_markdown_report(
best_metrics,
path=output_dir / 'report.md',
title='Experimental scGPT cell-type annotation summary',
extra_sections=[
'',
'## Interpretation notes',
'',
'- Built-in PBMC is the tutorial dataset, not the limit of the public API.',
'- Frozen probe answers the baseline question: do I need fine-tuning at all?',
'- Head-only tuning is the cheapest trainable path.',
'- LoRA is the first parameter-efficient adaptation path in scDLKit.',
'- Full backbone fine-tuning is intentionally deferred in this release line.',
],
)
report_path
PosixPath('artifacts/scgpt_cell_type_annotation/report.md')
Next tutorial#
Continue with the general Downstream Scanpy after scDLKit notebook if you want a broader Scanpy interpretation pass.
Revisit the frozen Experimental scGPT PBMC embeddings notebook if you want to isolate the embedding-only workflow before training.
output_paths = {
'report_md': str(output_dir / 'report.md'),
'report_csv': str(output_dir / 'report.csv'),
'strategy_metrics_csv': str(output_dir / 'strategy_metrics.csv'),
'frozen_embedding_umap': str(output_dir / 'frozen_embedding_umap.png'),
'lora_embedding_umap': str(output_dir / 'lora_embedding_umap.png'),
'best_strategy_confusion_matrix': str(output_dir / 'best_strategy_confusion_matrix.png'),
}
output_paths
{'report_md': 'artifacts/scgpt_cell_type_annotation/report.md',
'report_csv': 'artifacts/scgpt_cell_type_annotation/report.csv',
'strategy_metrics_csv': 'artifacts/scgpt_cell_type_annotation/strategy_metrics.csv',
'frozen_embedding_umap': 'artifacts/scgpt_cell_type_annotation/frozen_embedding_umap.png',
'lora_embedding_umap': 'artifacts/scgpt_cell_type_annotation/lora_embedding_umap.png',
'best_strategy_confusion_matrix': 'artifacts/scgpt_cell_type_annotation/best_strategy_confusion_matrix.png'}