Developer

Benchmarking

benchmarking.utils.compute_knn_graph_connectivities_and_distances(adata)

Compute approximate k-nearest-neighbors graph.
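To show what the connectivities and distances of a kNN graph contain, here is a minimal plain-Python sketch. It makes several substitutions. It uses an exact brute-force search instead of an approximate method. It takes a list of coordinate tuples instead of an AnnData object. The helper name knn_connectivities_and_distances is hypothetical. The resulting graph is left directed, whereas a real implementation may symmetrize it.

```python
import math


def knn_connectivities_and_distances(points, k):
    """Hypothetical helper: exact brute-force kNN graph (stand-in for an approximate search).

    Returns two dense n x n matrices (lists of lists). connectivities[i][j] is 1.0
    if j is among the k nearest neighbors of i, else 0.0. distances[i][j] holds the
    Euclidean distance for those neighbor pairs and 0.0 elsewhere.
    """
    n = len(points)
    connectivities = [[0.0] * n for _ in range(n)]
    distances = [[0.0] * n for _ in range(n)]
    for i in range(n):
        # Distances from point i to every other point (self excluded), sorted ascending.
        candidates = sorted(
            (math.dist(points[i], points[j]), j) for j in range(n) if j != i
        )
        for d, j in candidates[:k]:
            connectivities[i][j] = 1.0
            distances[i][j] = d
    return connectivities, distances
```

Each node points only to its own k nearest neighbors, so node i can be a neighbor of j without j being a neighbor of i.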

Data

data.initialize_dataloaders(node_masked_data)

Initialize edge-level and node-level training and validation dataloaders.

data.edge_level_split(data, edge_label_adj)

Split a PyG Data object into training, validation and test PyG Data objects using an edge-level split.
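The basic idea of an edge-level split is to shuffle the unique undirected edges and partition them by ratio. The sketch below works on plain (i, j) tuples and does not use PyG Data objects. The function name and the default ratios are hypothetical. It also skips negative-edge sampling, which a full link-prediction split usually performs.

```python
import random


def edge_level_split(edges, val_ratio=0.1, test_ratio=0.05, seed=0):
    """Hypothetical helper: randomly partition undirected edges into train/val/test lists."""
    # Canonicalize undirected edges as (min, max), dropping duplicates and self-loops.
    unique = sorted({(min(i, j), max(i, j)) for i, j in edges if i != j})
    random.Random(seed).shuffle(unique)
    n_val = int(len(unique) * val_ratio)
    n_test = int(len(unique) * test_ratio)
    val = unique[:n_val]
    test = unique[n_val:n_val + n_test]
    train = unique[n_val + n_test:]
    return train, val, test
```

The three edge sets are disjoint, so no edge in the validation or test set is visible to the model during training.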

data.node_level_split_mask(data[, ...])

Split data on node-level into training, validation and test sets by adding node-level masks (train_mask, val_mask, test_mask) to the PyG Data object.
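The sketch below shows node-level masking: each node is assigned to exactly one of three boolean masks. It uses plain lists rather than tensors attached to a PyG Data object. The function name and the ratio parameters are assumptions made for illustration.

```python
import random


def node_level_split_mask(n_nodes, val_ratio=0.1, test_ratio=0.0, seed=0):
    """Hypothetical helper: boolean train/val/test masks over nodes, one mask per node."""
    order = list(range(n_nodes))
    random.Random(seed).shuffle(order)
    n_val = int(n_nodes * val_ratio)
    n_test = int(n_nodes * test_ratio)
    train_mask = [False] * n_nodes
    val_mask = [False] * n_nodes
    test_mask = [False] * n_nodes
    for position, node in enumerate(order):
        if position < n_val:
            val_mask[node] = True
        elif position < n_val + n_test:
            test_mask[node] = True
        else:
            train_mask[node] = True
    return train_mask, val_mask, test_mask
```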

data.prepare_data(adata, ...[, adata_atac, ...])

Prepare data for model training including edge-level and node-level train, validation, and test splits.

data.SpatialAnnTorchDataset(adata, ...[, ...])

Spatially annotated torch dataset class to extract node features, node labels, adjacency matrix and edge indices in a standardized format from an AnnData object.
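PyG represents edge indices in COO format: a 2 x num_edges index whose first row holds source nodes and whose second row holds target nodes. The sketch below converts a dense adjacency matrix into that layout using plain lists. The helper name is hypothetical, and the sketch does not claim to show how the dataset class does the conversion internally.

```python
def adjacency_to_edge_index(adj):
    """Hypothetical helper: dense adjacency matrix -> [sources, targets] (COO layout)."""
    sources, targets = [], []
    for i, row in enumerate(adj):
        for j, value in enumerate(row):
            if value != 0:
                sources.append(i)
                targets.append(j)
    return [sources, targets]
```

For a symmetric adjacency matrix, each undirected edge appears twice, once in each direction.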

Models

models.utils.load_saved_files(dir_path, ...)

Helper to load saved model files.

models.utils.validate_var_names(adata, ...)

Helper to validate variable names.

models.utils.initialize_model(cls, adata, ...)

Helper to initialize a model.

Modules

modules.VGPGAE(n_input, n_fc_layers_encoder, ...)

Variational Gene Program Graph Autoencoder class.

modules.VGAEModuleMixin()

VGAE module mixin class containing universal VGAE module functionalities.

modules.BaseModuleMixin()

Base module mixin class containing universal module functionalities.

modules.compute_cat_covariates_contrastive_loss(...)

Compute the categorical covariates contrastive loss as a weighted binary cross-entropy loss with logits.

modules.compute_edge_recon_loss(...[, edge_incl])

Compute the edge reconstruction loss as a weighted binary cross-entropy loss with logits, using ground truth edge labels and predicted edge logits.
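A weighted binary cross-entropy with logits can be written in a numerically stable form using softplus: -log(sigmoid(z)) = softplus(-z) and -log(1 - sigmoid(z)) = softplus(z). The plain-Python sketch below uses a single positive-class weight, pos_weight. This weighting scheme mirrors common practice but is an assumption. The sketch does not reproduce the library's exact weighting or its edge_incl handling.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)); exp never receives a positive argument.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_bce_with_logits(logits, labels, pos_weight=1.0):
    """Mean of pos_weight * y * softplus(-z) + (1 - y) * softplus(z) over all edges."""
    total = 0.0
    for z, y in zip(logits, labels):
        total += pos_weight * y * softplus(-z) + (1 - y) * softplus(z)
    return total / len(logits)
```

Setting pos_weight to the ratio of negative to positive edges is one way to keep a rare positive class from being swamped.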

modules.compute_gp_group_lasso_reg_loss(model)

Compute group lasso regularization loss for the masked decoder layer weights to enforce gene program sparsity (each gene program is a group; groups are normalized by the number of non-masked weights per group).
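The sketch below illustrates a group lasso penalty over a masked weight matrix in which each column is one gene program. Only the non-masked weights of a column form its group. To account for group size, the sketch follows the standard group lasso convention of multiplying each group's L2 norm by the square root of the group size. The library's exact normalization may differ, and the function name is hypothetical.

```python
import math


def group_lasso_reg_loss(weights, mask):
    """Hypothetical helper. weights, mask: n_features x n_gene_programs (lists of lists).

    Sums sqrt(group size) * L2 norm over the non-masked weights of each gene program.
    """
    n_features = len(weights)
    n_gene_programs = len(weights[0])
    loss = 0.0
    for g in range(n_gene_programs):
        group = [weights[f][g] for f in range(n_features) if mask[f][g]]
        if not group:
            continue  # a gene program with no non-masked weights contributes nothing
        l2_norm = math.sqrt(sum(w * w for w in group))
        loss += math.sqrt(len(group)) * l2_norm
    return loss
```

The penalty drives all weights of a group toward zero together, which switches off entire gene programs rather than individual genes.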

modules.compute_gp_l1_reg_loss(model, gp_type)

Compute L1 regularization loss for the RNA decoder weights of gene programs of the type given by gp_type to encourage gene sparsity within those gene programs.

modules.compute_kl_reg_loss(mu, logstd)

Compute Kullback-Leibler divergence as per Kingma, D.
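For a diagonal Gaussian posterior N(mu, exp(logstd)^2) and a standard normal prior, the KL divergence has the closed form -0.5 * sum(1 + 2 * logstd - mu^2 - exp(logstd)^2). The sketch below computes this for a single latent vector in plain Python. How the library reduces over nodes and latent dimensions (sum or mean) is not shown here.

```python
import math


def kl_reg_loss(mu, logstd):
    """KL(N(mu, diag(exp(logstd)^2)) || N(0, I)) for one latent vector, summed over dimensions."""
    return -0.5 * sum(
        1 + 2 * ls - m * m - math.exp(2 * ls) for m, ls in zip(mu, logstd)
    )
```

The loss is zero exactly when the posterior equals the prior (mu = 0, logstd = 0).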

modules.compute_omics_recon_nb_loss(x, mu, theta)

Compute omics reconstruction loss according to a negative binomial model, which is often used to model omics count data such as scRNA-seq or scATAC-seq data.
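The sketch below computes the negative log-likelihood of counts x under a negative binomial with mean mu and inverse dispersion theta, using math.lgamma for the Gamma-function terms. The small eps inside the logarithms guards against log(0). This parameterization is standard, but the sketch is not the library's implementation.

```python
import math


def nb_negative_log_likelihood(x, mu, theta, eps=1e-8):
    """Negative binomial NLL summed over features; mu = mean, theta = inverse dispersion."""
    nll = 0.0
    for xi, mi, ti in zip(x, mu, theta):
        log_theta_mu = math.log(ti + mi + eps)
        log_likelihood = (
            math.lgamma(xi + ti) - math.lgamma(ti) - math.lgamma(xi + 1)
            + ti * (math.log(ti + eps) - log_theta_mu)
            + xi * (math.log(mi + eps) - log_theta_mu)
        )
        nll -= log_likelihood
    return nll
```

The variance is mu + mu^2 / theta, so a small theta lets the model capture overdispersed counts that a Poisson model cannot.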

NN

nn.OneHopAttentionNodeLabelAggregator(...[, ...])

One-hop Attention Node Label Aggregator class that uses a weighted sum of the omics features of a node's 1-hop neighbors to build an aggregated neighbor omics feature vector for a node.

nn.OneHopGCNNormNodeLabelAggregator(modality)

One-hop GCN Norm Node Label Aggregator class that uses a symmetrically normalized sum of the omics feature vectors of a node's 1-hop neighbors to build an aggregated neighbor omics feature vector for a node.
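Symmetric normalization weights the contribution of neighbor j to node i by A_ij / sqrt(deg_i * deg_j), so hub nodes do not dominate the aggregate. The plain-Python sketch below does not add self-loops and skips isolated nodes. Whether the library adds self-loops is an assumption not settled here, and the function name is hypothetical.

```python
import math


def gcn_norm_aggregate(adj, features):
    """Hypothetical helper: out_i = sum_j A_ij / sqrt(deg_i * deg_j) * x_j (no self-loops)."""
    n = len(adj)
    degrees = [sum(row) for row in adj]
    n_features = len(features[0])
    out = [[0.0] * n_features for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if adj[i][j] and degrees[i] > 0 and degrees[j] > 0:
                weight = adj[i][j] / math.sqrt(degrees[i] * degrees[j])
                for f in range(n_features):
                    out[i][f] += weight * features[j][f]
    return out
```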

nn.OneHopSumNodeLabelAggregator(modality)

One-hop Sum Node Label Aggregator class that sums up the omics features of a node's 1-hop neighbors to build an aggregated neighbor omics feature vector for a node.

nn.CosineSimGraphDecoder([dropout_rate])

Cosine similarity graph decoder class.
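A cosine similarity decoder scores each candidate edge (i, j) by the cosine of the angle between the latent vectors z_i and z_j. The result is a value in [-1, 1] that can be used as an edge logit. The sketch below omits the dropout that the class's dropout_rate parameter suggests, and the function name is hypothetical.

```python
import math


def cosine_sim_edge_logits(z, edge_pairs, eps=1e-8):
    """Hypothetical helper: cosine similarity of latent vectors for each (i, j) pair."""
    logits = []
    for i, j in edge_pairs:
        dot = sum(a * b for a, b in zip(z[i], z[j]))
        norm = math.sqrt(sum(a * a for a in z[i])) * math.sqrt(sum(b * b for b in z[j]))
        logits.append(dot / max(norm, eps))  # eps avoids division by zero for null vectors
    return logits
```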

nn.FCOmicsFeatureDecoder(modality, entity, ...)

Fully connected omics feature decoder class.

nn.MaskedOmicsFeatureDecoder(modality, ...)

Masked omics feature decoder class.

nn.Encoder(n_input, ...[, n_addon_latent, ...])

Encoder class.

nn.MaskedLinear(n_input, n_output, mask[, bias])

Masked linear class.
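A masked linear layer multiplies its weight matrix elementwise by a fixed binary mask before applying it. Masked-out connections therefore contribute nothing and never receive gradient. The single-sample plain-Python sketch below assumes the weight and mask are stored as n_output x n_input. That orientation and the function name are assumptions.

```python
def masked_linear(x, weight, mask, bias=None):
    """Hypothetical helper: y = (weight * mask) @ x + bias for one input vector x."""
    out = []
    for o in range(len(weight)):
        value = sum(xi * w * m for xi, w, m in zip(x, weight[o], mask[o]))
        if bias is not None:
            value += bias[o]
        out.append(value)
    return out
```

In a gene program decoder, the mask encodes which genes belong to which program, so each latent dimension can only reconstruct its own member genes.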

nn.AddOnMaskedLayer(n_input, n_output, mask, ...)

Add-on masked layer class.

Train

train.Trainer(adata, model[, adata_atac, ...])

Trainer class.

train.eval_metrics(edge_recon_probs, edge_labels)

Get the evaluation metrics for a (balanced) sample of positive and negative edges and a sample of nodes.
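Two metrics commonly computed for edge reconstruction are the area under the ROC curve and average precision. The plain-Python sketch below computes both; which metrics eval_metrics actually reports is not stated here. AUROC is the fraction of (positive, negative) pairs ranked correctly, with ties counted as half. Average precision is the mean precision at the rank of each positive. The sketch uses a simple tie-handling rule that may differ from library implementations such as scikit-learn.

```python
def auroc(scores, labels):
    """Fraction of positive/negative pairs where the positive scores higher (ties count 0.5)."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    wins = 0.0
    for p in positives:
        for q in negatives:
            if p > q:
                wins += 1.0
            elif p == q:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


def average_precision(scores, labels):
    """Mean of precision@rank taken at the rank of each positive (descending score order)."""
    ranked = sorted(zip(scores, labels), key=lambda t: t[0], reverse=True)
    hits = 0
    total = 0.0
    for rank, (_, y) in enumerate(ranked, start=1):
        if y == 1:
            hits += 1
            total += hits / rank
    return total / sum(labels)
```

Balancing positive and negative edges in the sample keeps AUROC and average precision comparable across graphs with different edge densities.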

train.plot_eval_metrics(eval_dict)

Plot evaluation metrics.