API reference
The public API stays intentionally compact.
Top-level package
Public package surface for scDLKit.
- class scdlkit.BaseModel(input_dim)
Bases: torch.nn.Module
Common base class for all registered models.
- Parameters:
input_dim (int)
- class scdlkit.BenchmarkResult(metrics_frame, runners, output_paths=<factory>)
Bases: object
Collected results from comparing multiple models.
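The `<factory>` default shown for `output_paths` is how Sphinx renders a dataclass field declared with `default_factory`, which gives each instance its own fresh container. A minimal sketch of that pattern. `BenchmarkResultSketch` is a hypothetical stand-in, and typing `output_paths` as a dict of paths is an assumption:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class BenchmarkResultSketch:
    """Hypothetical stand-in that mirrors BenchmarkResult's constructor."""

    metrics_frame: Any  # presumably a per-model metrics table in the real class
    runners: dict  # assumption: model name -> trained runner
    # default_factory creates a new dict per instance; Sphinx shows it as <factory>
    output_paths: dict = field(default_factory=dict)


first = BenchmarkResultSketch(metrics_frame=[], runners={})
second = BenchmarkResultSketch(metrics_frame=[], runners={})
first.output_paths["metrics"] = "metrics.csv"
print(second.output_paths)  # prints {}: the two dicts are not shared
```

Writing `output_paths: dict = {}` instead would raise `ValueError` when the class is defined, because dataclasses reject mutable defaults.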
- class scdlkit.PreparedData(train, val, test, input_dim, feature_names, label_encoder, batch_encoder, preprocessing)
Bases: object
Prepared train/validation/test splits and metadata.
- class scdlkit.TaskRunner(*, model, task, latent_dim=32, hidden_dims=(512, 256), epochs=50, batch_size=256, lr=0.001, device='auto', mixed_precision=False, early_stopping_patience=10, checkpoint=True, seed=42, layer='X', use_hvg=False, n_top_genes=2000, normalize=False, log1p=False, scale=False, label_key=None, batch_key=None, val_size=0.15, test_size=0.15, batch_aware_split=False, random_state=42, output_dir=None, model_kwargs=None)
Bases: object
Beginner-facing training, evaluation, and visualization workflow.
- Parameters:
task (str)
latent_dim (int)
epochs (int)
batch_size (int)
lr (float)
device (str)
mixed_precision (bool)
early_stopping_patience (int)
checkpoint (bool)
seed (int)
layer (str)
use_hvg (bool)
n_top_genes (int)
normalize (bool)
log1p (bool)
scale (bool)
label_key (str | None)
batch_key (str | None)
val_size (float)
test_size (float)
batch_aware_split (bool)
random_state (int)
output_dir (str | None)
- encode(adata)
Encode new AnnData into latent representations.
- Parameters:
adata (AnnData)
- Return type:
ndarray
- evaluate(adata=None)
Evaluate on held-out test data or on a provided AnnData object.
- fit(adata, *, val_adata=None, test_adata=None)
Prepare data, instantiate the model, and train it.
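`use_hvg` and `n_top_genes` restrict training to a subset of highly variable genes. The selection method scDLKit uses is not documented here; the sketch below ranks genes by plain per-gene variance as a simplified stand-in (Scanpy's `highly_variable_genes`, for instance, uses more elaborate dispersion-based criteria), and `select_top_variable_genes` is a hypothetical helper:

```python
from statistics import pvariance


def select_top_variable_genes(matrix, gene_names, n_top_genes=2000):
    """Return the names of the n_top_genes genes with the highest variance.

    matrix is a list of cells, each a list of per-gene expression values.
    Simplified stand-in: ranks genes by population variance across cells.
    """
    n_genes = len(gene_names)
    variances = [pvariance([cell[j] for cell in matrix]) for j in range(n_genes)]
    # Sort gene indices by variance, highest first, and keep the top n
    ranked = sorted(range(n_genes), key=lambda j: variances[j], reverse=True)
    keep = sorted(ranked[:n_top_genes])  # restore the original gene order
    return [gene_names[j] for j in keep]


cells = [
    [0.0, 5.0, 1.0],
    [0.0, 1.0, 1.2],
    [0.0, 9.0, 0.8],
]
print(select_top_variable_genes(cells, ["GeneA", "GeneB", "GeneC"], n_top_genes=2))
# ['GeneB', 'GeneC']
```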
- class scdlkit.Trainer(model, task, *, epochs=50, batch_size=256, lr=0.001, device='auto', mixed_precision=False, early_stopping_patience=10, checkpoint=True, seed=42)
Bases: object
Train scDLKit models with a task adapter.
- fit(train_data, val_data=None)
Train the model and restore the best checkpointed state.
- property history_frame: DataFrame
Training history as a DataFrame.
- predict_dataset(data)
Run inference on a dataset and collect batched outputs.
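`Trainer.fit` restores the best checkpointed state, and `early_stopping_patience` bounds how long training continues without improvement. A minimal sketch of that logic, assuming "best" means lowest validation loss and that patience counts consecutive non-improving epochs; the per-epoch losses are supplied as a list rather than produced by a real model, and `train_with_early_stopping` is hypothetical:

```python
import copy


def train_with_early_stopping(val_losses, patience, initial_state):
    """Simulate an epoch loop; return (best_epoch, best_state, epochs_run)."""
    state = dict(initial_state)
    best_loss = float("inf")
    best_state = copy.deepcopy(state)
    best_epoch = -1
    epochs_without_improvement = 0
    epochs_run = 0
    for epoch, loss in enumerate(val_losses):
        epochs_run += 1
        state["epoch"] = epoch  # stand-in for updated model weights
        if loss < best_loss:
            best_loss = loss
            best_epoch = epoch
            best_state = copy.deepcopy(state)  # "checkpoint" the best weights
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # no improvement for `patience` consecutive epochs
    return best_epoch, best_state, epochs_run


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]
best_epoch, best_state, epochs_run = train_with_early_stopping(
    losses, patience=3, initial_state={"epoch": None}
)
print(best_epoch, best_state, epochs_run)  # 2 {'epoch': 2} 6
```

Training stops after epoch 5, and the returned state is the one from epoch 2, not the last one seen.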
- scdlkit.compare_models(adata, *, models, task, shared_kwargs=None, output_dir=None)
Train and evaluate several models with shared configuration.
- scdlkit.create_model(name, **kwargs)
Instantiate a registered model by name.
- scdlkit.prepare_data(adata, *, layer='X', use_hvg=False, n_top_genes=2000, normalize=False, log1p=False, scale=False, label_key=None, batch_key=None, val_size=0.15, test_size=0.15, batch_aware_split=False, random_state=42, copy=True)
Prepare AnnData splits and preprocessing metadata.
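`val_size` and `test_size` are the fractions of cells held out for validation and testing, with `random_state` fixing the shuffle. A minimal sketch of such a split over cell indices, assuming a simple random (not batch-aware) split with sizes rounded down; `split_indices` is hypothetical and not scDLKit's implementation:

```python
import random


def split_indices(n_cells, val_size=0.15, test_size=0.15, random_state=42):
    """Shuffle cell indices reproducibly and cut them into train/val/test."""
    indices = list(range(n_cells))
    random.Random(random_state).shuffle(indices)  # seeded, isolated RNG
    n_test = int(n_cells * test_size)
    n_val = int(n_cells * val_size)
    test = indices[:n_test]
    val = indices[n_test:n_test + n_val]
    train = indices[n_test + n_val:]
    return train, val, test


train, val, test = split_indices(100)
print(len(train), len(val), len(test))  # 70 15 15
```

Using a dedicated `random.Random(random_state)` instead of reseeding the global generator keeps the split reproducible without side effects on other code.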
Data preparation
AnnData preparation and transformation utilities.
- scdlkit.data.prepare.prepare_data(adata, *, layer='X', use_hvg=False, n_top_genes=2000, normalize=False, log1p=False, scale=False, label_key=None, batch_key=None, val_size=0.15, test_size=0.15, batch_aware_split=False, random_state=42, copy=True)
Prepare AnnData splits and preprocessing metadata.
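The `normalize`, `log1p`, and `scale` flags name common single-cell preprocessing steps. Their exact settings in scDLKit are not documented here; the sketch below assumes library-size normalization to 10,000 counts per cell, natural-log `log1p`, and per-gene z-scoring, applied in that order, which mirrors the conventional Scanpy recipe. `preprocess` is a hypothetical helper:

```python
import math


def preprocess(matrix, normalize=False, log1p=False, scale=False, target_sum=1e4):
    """Apply normalize -> log1p -> scale to a cells x genes list of lists."""
    data = [[float(v) for v in cell] for cell in matrix]  # work on a copy
    if normalize:
        # Rescale each cell so its counts sum to target_sum (assumed 10,000)
        normalized = []
        for cell in data:
            total = sum(cell)
            factor = target_sum / total if total > 0 else 0.0
            normalized.append([v * factor for v in cell])
        data = normalized
    if log1p:
        # Natural log of (1 + value), accurate for values near zero
        data = [[math.log1p(v) for v in cell] for cell in data]
    if scale:
        # Z-score each gene (column) across cells
        n_cells = len(data)
        for j in range(len(data[0])):
            column = [cell[j] for cell in data]
            mean = sum(column) / n_cells
            std = math.sqrt(sum((v - mean) ** 2 for v in column) / n_cells)
            for cell in data:
                # A zero-variance gene is constant, so centring gives 0
                cell[j] = (cell[j] - mean) / std if std > 0 else 0.0
    return data


counts = [[1, 3], [2, 2], [0, 5]]
print(preprocess(counts, normalize=True))
# [[2500.0, 7500.0], [5000.0, 5000.0], [0.0, 10000.0]]
```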
- scdlkit.data.prepare.transform_adata(adata, preprocessing, *, label_encoder=None, batch_encoder=None, copy=True)
Transform new AnnData using stored preprocessing metadata.
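`transform_adata` reuses the preprocessing metadata recorded when the training data was prepared, so new cells are transformed with training-set statistics rather than their own. A minimal sketch of that fit-then-transform pattern, assuming the stored metadata holds the training feature names plus per-gene means and standard deviations; `fit_preprocessing` and `transform_new` are hypothetical helpers:

```python
import math


def fit_preprocessing(matrix, feature_names):
    """Record per-gene mean and std from training data (cells x genes)."""
    n_cells = len(matrix)
    means, stds = [], []
    for j in range(len(feature_names)):
        column = [cell[j] for cell in matrix]
        mean = sum(column) / n_cells
        means.append(mean)
        stds.append(math.sqrt(sum((v - mean) ** 2 for v in column) / n_cells))
    return {"feature_names": list(feature_names), "means": means, "stds": stds}


def transform_new(matrix, gene_names, preprocessing):
    """Reorder new data to the training features and scale with stored stats."""
    position = {name: i for i, name in enumerate(gene_names)}
    missing = [g for g in preprocessing["feature_names"] if g not in position]
    if missing:
        raise KeyError(f"new data lacks training features: {missing}")
    stats = zip(preprocessing["feature_names"], preprocessing["means"], preprocessing["stds"])
    stats = list(stats)
    out = []
    for cell in matrix:
        # Genes constant in training (std == 0) are mapped to 0
        out.append([
            (cell[position[name]] - mean) / std if std > 0 else 0.0
            for name, mean, std in stats
        ])
    return out


meta = fit_preprocessing([[1.0, 10.0], [3.0, 30.0]], ["GeneA", "GeneB"])
# New data arrives with its genes in a different order
print(transform_new([[30.0, 1.0]], ["GeneB", "GeneA"], meta))  # [[-1.0, 1.0]]
```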
Training
Plain PyTorch training loop.
- class scdlkit.training.trainer.Trainer(model, task, *, epochs=50, batch_size=256, lr=0.001, device='auto', mixed_precision=False, early_stopping_patience=10, checkpoint=True, seed=42)
Bases: object
Train scDLKit models with a task adapter.
- fit(train_data, val_data=None)
Train the model and restore the best checkpointed state.
- property history_frame: DataFrame
Training history as a DataFrame.
- predict_dataset(data)
Run inference on a dataset and collect batched outputs.
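`predict_dataset` runs inference in batches and collects the outputs. A minimal sketch of that pattern, assuming outputs are concatenated in input order; the plain function `double` stands in for the model, and `batch_size` matches the `Trainer` default of 256:

```python
def predict_in_batches(data, predict, batch_size=256):
    """Apply predict() to consecutive slices of data and concatenate results."""
    outputs = []
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]  # the last batch may be smaller
        outputs.extend(predict(batch))
    return outputs


calls = []


def double(batch):
    calls.append(len(batch))  # record batch sizes for illustration
    return [2 * x for x in batch]


result = predict_in_batches(list(range(600)), double)
print(calls)  # [256, 256, 88]
```

In a PyTorch implementation this loop would typically run under `torch.no_grad()` with the model in `eval()` mode.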
Model registry and implementations
Model registry helpers.
- scdlkit.models.registry.create_model(name, **kwargs)
Instantiate a registered model by name.
- scdlkit.models.registry.register_model(name, *aliases)
Register a model factory under one or more names.
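`register_model` maps a name and optional aliases to a model factory, and `create_model` looks one up by name and passes keyword arguments through. A minimal sketch of that registry pattern, assuming `register_model` is used as a class decorator; the `_REGISTRY` dict, the `TinyModel` class, and the names `"autoencoder"` and `"ae"` are hypothetical:

```python
_REGISTRY = {}


def register_model(name, *aliases):
    """Return a decorator that registers a factory under name and aliases."""
    def decorator(factory):
        for key in (name, *aliases):
            if key in _REGISTRY:
                raise ValueError(f"model name already registered: {key!r}")
            _REGISTRY[key] = factory
        return factory  # leave the decorated class unchanged
    return decorator


def create_model(name, **kwargs):
    """Instantiate the factory registered under name with keyword arguments."""
    try:
        factory = _REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(_REGISTRY))
        raise ValueError(f"unknown model {name!r}; available: {known}") from None
    return factory(**kwargs)


@register_model("autoencoder", "ae")
class TinyModel:
    def __init__(self, input_dim, latent_dim=32):
        self.input_dim = input_dim
        self.latent_dim = latent_dim


model = create_model("ae", input_dim=2000, latent_dim=16)
print(type(model).__name__, model.input_dim, model.latent_dim)  # TinyModel 2000 16
```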
MLP autoencoder baseline.
- class scdlkit.models.autoencoder.AutoEncoder(input_dim, latent_dim=32, hidden_dims=(512, 256), dropout=0.1)
Bases: BaseModel
Simple MLP autoencoder.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
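With the defaults `hidden_dims=(512, 256)` and `latent_dim=32`, the encoder narrows the input through the hidden widths down to the latent size. A minimal sketch of how the layer shapes can be derived, assuming the decoder mirrors the encoder (a common convention this reference does not confirm); `layer_sizes` is a hypothetical helper:

```python
def layer_sizes(input_dim, latent_dim=32, hidden_dims=(512, 256)):
    """Return (in_features, out_features) pairs for encoder and decoder layers."""
    encoder_dims = [input_dim, *hidden_dims, latent_dim]
    decoder_dims = encoder_dims[::-1]  # assumption: mirrored decoder
    encoder = list(zip(encoder_dims[:-1], encoder_dims[1:]))
    decoder = list(zip(decoder_dims[:-1], decoder_dims[1:]))
    return encoder, decoder


enc, dec = layer_sizes(2000)
print(enc)  # [(2000, 512), (512, 256), (256, 32)]
print(dec)  # [(32, 256), (256, 512), (512, 2000)]
```

Each pair would correspond to one `torch.nn.Linear(in_features, out_features)` layer.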
Variational autoencoder baseline.
- class scdlkit.models.vae.VariationalAutoEncoder(input_dim, latent_dim=32, hidden_dims=(512, 256), dropout=0.1, kl_weight=1.0)
Bases: BaseModel
Variational autoencoder with an MLP encoder/decoder.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
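The `kl_weight` parameter points to the usual VAE objective: a reconstruction loss plus a weighted KL divergence between the encoder's diagonal Gaussian and a standard normal prior. A minimal sketch of that loss for a single cell, assuming mean-squared-error reconstruction and a log-variance (`logvar`) parameterisation; scDLKit's exact reconstruction term and reduction are not documented here:

```python
import math


def vae_loss(x, x_hat, mu, logvar, kl_weight=1.0):
    """Return (total, reconstruction, kl) for one cell."""
    # Mean squared error between the input and its reconstruction
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    # Closed-form KL( N(mu, exp(logvar)) || N(0, 1) ), summed over latent dims
    kl = -0.5 * sum(1 + lv - m ** 2 - math.exp(lv) for m, lv in zip(mu, logvar))
    return recon + kl_weight * kl, recon, kl


total, recon, kl = vae_loss([1.0, 2.0], [1.0, 1.0], mu=[1.0, 0.0], logvar=[0.0, 0.0])
print(total, recon, kl)  # 1.0 0.5 0.5
```

Raising `kl_weight` above 1.0 pulls the latent codes harder toward the prior, while values below 1.0 favour reconstruction.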
Denoising autoencoder baseline.
- class scdlkit.models.denoising.DenoisingAutoEncoder(input_dim, latent_dim=32, hidden_dims=(512, 256), dropout=0.1, noise_probability=0.15)
Bases: AutoEncoder
Autoencoder with input masking noise during training.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
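`noise_probability=0.15` controls the input masking applied during training. A minimal sketch, assuming "masking" means each expression value is independently zeroed with that probability; `mask_input` is a hypothetical helper:

```python
import random


def mask_input(x, noise_probability=0.15, rng=None):
    """Return a copy of x with each entry set to 0.0 with the given probability."""
    rng = rng or random.Random()
    return [0.0 if rng.random() < noise_probability else v for v in x]


rng = random.Random(0)
clean = [1.0] * 10_000
noisy = mask_input(clean, noise_probability=0.15, rng=rng)
masked_fraction = noisy.count(0.0) / len(noisy)
print(round(masked_fraction, 2))  # roughly 0.15; the exact value depends on the seed
```

Under this reading, the model sees `mask_input(x)` during training but is scored on reconstructing the clean `x`, which forces it to infer masked values from the remaining genes.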
Patch-based transformer autoencoder for tabular single-cell inputs.
- class scdlkit.models.transformer.TransformerAutoEncoder(input_dim, latent_dim=32, hidden_dims=None, patch_size=16, d_model=128, n_heads=4, n_layers=2, decoder_hidden_dims=(256, 128), dropout=0.1)
Bases: BaseModel
Patch-based transformer autoencoder.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
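The transformer treats a cell's expression vector as a sequence of fixed-width patches (`patch_size=16`), each of which is then embedded to `d_model` dimensions. A minimal sketch of the patching step only, assuming the final patch is zero-padded when `input_dim` is not a multiple of `patch_size` (the padding rule is an assumption); `to_patches` is a hypothetical helper:

```python
import math


def to_patches(x, patch_size=16):
    """Split a 1-D expression vector into zero-padded patches of patch_size."""
    n_patches = math.ceil(len(x) / patch_size)
    padded = list(x) + [0.0] * (n_patches * patch_size - len(x))
    return [padded[i * patch_size:(i + 1) * patch_size] for i in range(n_patches)]


patches = to_patches(list(range(1, 41)), patch_size=16)
print(len(patches), patches[-1])
# 3 [33, 34, 35, 36, 37, 38, 39, 40, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```

With 2,000 genes and the default patch size, this scheme yields 125 patches, so attention would run over a sequence of length 125 rather than 2,000.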
MLP classification baseline.
- class scdlkit.models.classifier.MLPClassifier(input_dim, num_classes, hidden_dims=(256, 128), dropout=0.2)
Bases: BaseModel
Simple classifier over preprocessed expression features.
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
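`MLPClassifier` produces one score per class (`num_classes`), so predictions are read off with an argmax and mapped back to label names. A minimal sketch, assuming labels are integer-encoded in sorted order (as scikit-learn's `LabelEncoder` does); `encode_labels` and `decode_predictions` are hypothetical helpers:

```python
def encode_labels(labels):
    """Map string labels to integers 0..num_classes-1 in sorted order."""
    classes = sorted(set(labels))
    index = {label: i for i, label in enumerate(classes)}
    return [index[label] for label in labels], classes


def decode_predictions(logits, classes):
    """Pick the highest-scoring class per cell and return its label."""
    predictions = []
    for scores in logits:
        best = max(range(len(scores)), key=scores.__getitem__)  # argmax
        predictions.append(classes[best])
    return predictions


codes, classes = encode_labels(["T cell", "B cell", "T cell", "NK cell"])
print(codes, classes)  # [2, 0, 2, 1] ['B cell', 'NK cell', 'T cell']
print(decode_predictions([[0.1, 0.2, 3.0], [2.5, 0.0, -1.0]], classes))  # ['T cell', 'B cell']
```

Here `num_classes` would be `len(classes)`, i.e. 3.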
Comparison
Compare multiple models on the same AnnData workflow.
- class scdlkit.evaluation.compare.BenchmarkResult(metrics_frame, runners, output_paths=<factory>)
Bases: object
Collected results from comparing multiple models.
- scdlkit.evaluation.compare.compare_models(adata, *, models, task, shared_kwargs=None, output_dir=None)
Train and evaluate several models with shared configuration.
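`compare_models` trains each named model with the same `shared_kwargs` and gathers the results into one table. A minimal sketch of that loop. `train_and_evaluate` is a hypothetical stand-in for fitting a `TaskRunner` and calling `evaluate()`, rows are plain dicts rather than a DataFrame, and `"example"` is a placeholder task name:

```python
def compare(models, task, train_and_evaluate, shared_kwargs=None):
    """Run every model with identical settings and collect one row per model."""
    shared_kwargs = dict(shared_kwargs or {})
    rows, runners = [], {}
    for name in models:
        # Every model gets the same configuration; only the model name changes
        runner, metrics = train_and_evaluate(model=name, task=task, **shared_kwargs)
        runners[name] = runner
        rows.append({"model": name, **metrics})
    return rows, runners


def fake_train_and_evaluate(model, task, epochs=50):
    # Hypothetical stand-in: returns a dummy runner and a placeholder metric
    return f"{model}-runner", {"epochs_used": epochs}


rows, runners = compare(["autoencoder", "vae"], "example", fake_train_and_evaluate, {"epochs": 5})
print(rows)  # [{'model': 'autoencoder', 'epochs_used': 5}, {'model': 'vae', 'epochs_used': 5}]
```

Holding the configuration fixed across models is what makes the collected metrics comparable.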