nichecompass.models.NicheCompass

class nichecompass.models.NicheCompass(adata, adata_atac=None, counts_key='counts', adj_key='spatial_connectivities', gp_names_key='nichecompass_gp_names', active_gp_names_key='nichecompass_active_gp_names', gp_targets_mask_key='nichecompass_gp_targets', gp_targets_categories_mask_key='nichecompass_gp_targets_categories', targets_categories_label_encoder_key='nichecompass_targets_categories_label_encoder', gp_sources_mask_key='nichecompass_gp_sources', gp_sources_categories_mask_key='nichecompass_gp_sources_categories', sources_categories_label_encoder_key='nichecompass_sources_categories_label_encoder', ca_targets_mask_key='nichecompass_ca_targets', ca_sources_mask_key='nichecompass_ca_sources', latent_key='nichecompass_latent', cat_covariates_embeds_keys=None, cat_covariates_embeds_injection=['gene_expr_decoder', 'chrom_access_decoder'], cat_covariates_keys=None, cat_covariates_no_edges=None, genes_idx_key='nichecompass_genes_idx', target_genes_idx_key='nichecompass_target_genes_idx', source_genes_idx_key='nichecompass_source_genes_idx', peaks_idx_key='nichecompass_peaks_idx', target_peaks_idx_key='nichecompass_target_peaks_idx', source_peaks_idx_key='nichecompass_source_peaks_idx', gene_peaks_mask_key='nichecompass_gene_peaks', recon_adj_key='nichecompass_recon_connectivities', agg_weights_key='nichecompass_agg_weights', include_edge_recon_loss=True, include_gene_expr_recon_loss=True, include_chrom_access_recon_loss=True, include_cat_covariates_contrastive_loss=False, gene_expr_recon_dist='nb', log_variational=True, node_label_method='one-hop-norm', active_gp_thresh_ratio=0.01, active_gp_type='separate', n_fc_layers_encoder=1, n_layers_encoder=1, n_hidden_encoder=None, conv_layer_encoder='gatv2conv', encoder_n_attention_heads=4, encoder_use_bn=False, dropout_rate_encoder=0.0, dropout_rate_graph_decoder=0.0, cat_covariates_cats=None, n_addon_gp=100, cat_covariates_embeds_nums=None, include_edge_kl_loss=True, use_cuda_if_available=True, seed=0, **kwargs)

NicheCompass model class.

Parameters:
  • adata (AnnData) – AnnData object with gene expression raw counts stored in ´adata.layers[counts_key]´ or ´adata.X´, depending on ´counts_key´, sparse adjacency matrix stored in ´adata.obsp[adj_key]´, gene program names stored in ´adata.uns[gp_names_key]´, and binary gene program targets and sources masks stored in ´adata.varm[gp_targets_mask_key]´ and ´adata.varm[gp_sources_mask_key]´ respectively.

  • adata_atac (Optional[AnnData] (default: None)) – Optional AnnData object with paired spatial chromatin accessibility raw counts stored in ´adata_atac.X´, and sparse boolean chromatin accessibility targets and sources masks stored in ´adata_atac.varm[ca_targets_mask_key]´ and ´adata_atac.varm[ca_sources_mask_key]´ respectively.

  • counts_key (Optional[str] (default: 'counts')) – Key under which the gene expression raw counts are stored in ´adata.layers´. If ´None´, uses ´adata.X´ as counts.

  • adj_key (str (default: 'spatial_connectivities')) – Key under which the sparse adjacency matrix is stored in ´adata.obsp´.

  • gp_names_key (str (default: 'nichecompass_gp_names')) – Key under which the gene program names are stored in ´adata.uns´.

  • active_gp_names_key (str (default: 'nichecompass_active_gp_names')) – Key under which the active gene program names will be stored in ´adata.uns´.

  • gp_targets_mask_key (str (default: 'nichecompass_gp_targets')) – Key under which the gene program targets mask is stored in ´adata.varm´.

  • gp_sources_mask_key (str (default: 'nichecompass_gp_sources')) – Key under which the gene program sources mask is stored in ´adata.varm´.

  • ca_targets_mask_key (Optional[str] (default: 'nichecompass_ca_targets')) – Key under which the chromatin accessibility targets mask is stored in ´adata_atac.varm´.

  • ca_sources_mask_key (Optional[str] (default: 'nichecompass_ca_sources')) – Key under which the chromatin accessibility sources mask is stored in ´adata_atac.varm´.

  • latent_key (str (default: 'nichecompass_latent')) – Key under which the latent / gene program representation of active gene programs will be stored in ´adata.obsm´ after model training.

  • cat_covariates_keys (Optional[List[str]] (default: None)) – Keys under which the categorical covariates are stored in ´adata.obs´.

  • cat_covariates_no_edges (Optional[List[bool]] (default: None)) – List of booleans that indicate whether there can be edges between different categories of the categorical covariates. If this is ´True´ for a specific categorical covariate, this covariate will be excluded from the edge reconstruction loss.

  • cat_covariates_embeds_keys (Optional[List[str]] (default: None)) – Keys under which the categorical covariates embeddings will be stored in ´adata.uns´.

  • cat_covariates_embeds_injection (Optional[List[Literal['encoder', 'gene_expr_decoder', 'chrom_access_decoder']]] (default: ['gene_expr_decoder', 'chrom_access_decoder'])) – List of VGPGAE modules in which the categorical covariates embeddings are injected.

  • genes_idx_key (str (default: 'nichecompass_genes_idx')) – Key in ´adata.uns´ where the index of a concatenated vector of target and source genes that are in the gene program masks is stored.

  • target_genes_idx_key (str (default: 'nichecompass_target_genes_idx')) – Key in ´adata.uns´ where the index of target genes that are in the gene program masks is stored.

  • source_genes_idx_key (str (default: 'nichecompass_source_genes_idx')) – Key in ´adata.uns´ where the index of source genes that are in the gene program masks is stored.

  • peaks_idx_key (str (default: 'nichecompass_peaks_idx')) – Key in ´adata_atac.uns´ where the index of a concatenated vector of target and source peaks that are in the chromatin accessibility masks is stored.

  • target_peaks_idx_key (str (default: 'nichecompass_target_peaks_idx')) – Key in ´adata_atac.uns´ where the index of target peaks that are in the chromatin accessibility masks is stored.

  • source_peaks_idx_key (str (default: 'nichecompass_source_peaks_idx')) – Key in ´adata_atac.uns´ where the index of source peaks that are in the chromatin accessibility masks is stored.

  • gene_peaks_mask_key (str (default: 'nichecompass_gene_peaks')) – Key in ´adata.varm´ where the gene peak mapping mask is stored.

  • recon_adj_key (Optional[str] (default: 'nichecompass_recon_connectivities')) – Key in ´adata.obsp´ where the reconstructed adjacency matrix edge probabilities will be stored.

  • agg_weights_key (Optional[str] (default: 'nichecompass_agg_weights')) – Key in ´adata.obsp´ where the aggregation weights of the node label aggregator will be stored.

  • include_edge_recon_loss (bool (default: True)) – If True, includes the edge reconstruction loss in the backpropagation.

  • include_gene_expr_recon_loss (bool (default: True)) – If True, includes the gene expression reconstruction loss in the backpropagation.

  • include_chrom_access_recon_loss (Optional[bool] (default: True)) – If True, includes the chromatin accessibility reconstruction loss in the backpropagation.

  • include_cat_covariates_contrastive_loss (bool (default: False)) – If True, includes the categorical covariates contrastive loss in the backpropagation.

  • gene_expr_recon_dist (Literal['nb'] (default: 'nb')) – The distribution used for gene expression reconstruction. If ´nb´, uses a negative binomial distribution. If ´zinb´, uses a zero-inflated negative binomial distribution.

  • log_variational (bool (default: True)) – If ´True´, transforms x by log(x+1) prior to encoding for numerical stability (not for normalization).

  • node_label_method (Literal['one-hop-sum', 'one-hop-norm', 'one-hop-attention'] (default: 'one-hop-norm')) – Node label method that will be used for omics reconstruction. If ´self´, uses only the input features of the node itself as node labels for omics reconstruction. If ´one-hop-sum´, uses a concatenation of the node’s input features with the sum of the input features of all nodes in the node’s one-hop neighborhood. If ´one-hop-norm´, uses a concatenation of the node’s input features with the input features of the node’s one-hop neighbors, normalized as per Kipf, T. N. & Welling, M. Semi-Supervised Classification with Graph Convolutional Networks. arXiv [cs.LG] (2016). If ´one-hop-attention´, uses a concatenation of the node’s input features with the input features of the node’s one-hop neighbors, weighted by an attention mechanism.

  • active_gp_thresh_ratio (float (default: 0.01)) – Ratio that determines which gene programs are considered active and are used in the latent representation after model training. All inactive gene programs will be dropped during model training after a determined number of epochs. Aggregations of the absolute values of the gene weights of the gene expression decoder per gene program are calculated. The maximum value, i.e. the value of the gene program with the highest aggregated value will be used as a benchmark and all gene programs whose aggregated value is smaller than ´active_gp_thresh_ratio´ times this maximum value will be set to inactive. If ´==0´, all gene programs will be considered active. More information can be found in ´self.model.get_active_gp_mask()´.

  • active_gp_type (Literal['mixed', 'separate'] (default: 'separate')) – Type to determine active gene programs. Can be ´mixed´, in which case active gene programs are determined jointly across prior and add-on gene programs, or ´separate´, in which case they are determined separately for prior and add-on gene programs.

  • n_fc_layers_encoder (int (default: 1)) – Number of fully connected layers in the encoder before message passing layers.

  • n_layers_encoder (int (default: 1)) – Number of message passing layers in the encoder.

  • n_hidden_encoder (Optional[int] (default: None)) – Number of nodes in the encoder hidden layers. If ´None´, it is determined automatically based on the number of input genes and gene programs.

  • conv_layer_encoder (Literal['gcnconv', 'gatv2conv'] (default: 'gatv2conv')) – Convolutional layer used as GNN in the encoder.

  • encoder_n_attention_heads (Optional[int] (default: 4)) – Only relevant if ´conv_layer_encoder == gatv2conv´. Number of attention heads used in the GNN layers of the encoder.

  • encoder_use_bn (bool (default: False)) – If ´True´, uses a batch normalization layer at the end of the encoder to normalize ´mu´.

  • dropout_rate_encoder (float (default: 0.0)) – Probability that nodes will be dropped in the encoder during training.

  • dropout_rate_graph_decoder (float (default: 0.0)) – Probability that nodes will be dropped in the graph decoder during training.

  • cat_covariates_cats (Optional[List[List]] (default: None)) – List of category lists for each categorical covariate to get the right encoding when used after reloading.

  • n_addon_gp (int (default: 100)) – Number of addon gene programs (i.e. gene programs that are not included in masks but can be learned de novo).

  • cat_covariates_embeds_nums (Optional[List[int]] (default: None)) – List of number of embedding nodes for all categorical covariates.

  • use_cuda_if_available (bool (default: True)) – If True, use cuda if available.

  • seed (int (default: 0)) – Random seed to get reproducible results.

  • kwargs – NicheCompass kwargs (to support legacy versions).
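To make the ´one-hop-norm´ option of ´node_label_method´ more concrete, here is a minimal sketch in plain Python. It assumes the symmetric normalization from Kipf & Welling (each neighbor's features weighted by 1 / sqrt(deg_i * deg_j)) applied without self-loops. Whether NicheCompass adds self-loops is not stated in this reference, and the helper name ´one_hop_norm_labels´ is hypothetical rather than part of the NicheCompass API.

```python
import math


def one_hop_norm_labels(features, adj):
    """Concatenate each node's features with a normalized neighbor aggregate.

    features: list of per-node feature lists.
    adj: symmetric 0/1 adjacency matrix as a list of lists (no self-loops assumed).
    Hypothetical helper for illustration only.
    """
    n = len(features)
    deg = [sum(row) for row in adj]
    labels = []
    for i in range(n):
        agg = [0.0] * len(features[i])
        for j in range(n):
            if adj[i][j] and deg[i] > 0 and deg[j] > 0:
                # Symmetric normalization: 1 / sqrt(deg_i * deg_j)
                w = adj[i][j] / math.sqrt(deg[i] * deg[j])
                for k, x in enumerate(features[j]):
                    agg[k] += w * x
        # Node label = own features followed by the neighborhood aggregate
        labels.append(list(features[i]) + agg)
    return labels


features = [[1.0, 0.0], [0.0, 2.0], [4.0, 4.0]]
adj = [[0, 1, 1], [1, 0, 0], [1, 0, 0]]
node_labels = one_hop_norm_labels(features, adj)
```

Each node label has twice the input dimensionality: the first half is the node's own features and the second half is the aggregated neighborhood. Replacing the normalized weight with 1 would correspond to the ´one-hop-sum´ variant described above.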

Methods table

add_active_gp_scores_to_obs()

Add the expression of all active gene programs to ´adata.obs´.

compute_gp_gene_importances(selected_gp)

Compute gene importances for the genes of a given gene program.

compute_gp_peak_importances(selected_gp)

Compute peak importances for the peaks of a given gene program.

get_active_gps()

Get active gene programs based on the gene expression decoder gene weights of gene programs.

get_cat_covariates_embeds()

Get the categorical covariates embeddings.

get_gp_data([selected_gps])

Get the index of selected gene programs as well as their omics decoder weights.

get_gp_summary()

Get summary information of gene programs and return it as a DataFrame.

get_latent_representation([adata, ...])

Get the latent representation / gene program scores from a trained model.

get_neighbor_importances([node_batch_size])

Get the aggregation weights of the node label aggregator.

get_omics_decoder_outputs([adata, ...])

Get the omics decoder outputs.

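get_recon_edge_probs([node_batch_size, ...])

Get the reconstructed adjacency matrix (or the edge probability matrix) from a trained NicheCompass model.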

load(dir_path[, adata, adata_atac, ...])

Instantiate a model from saved output.

run_differential_gp_tests(cat_key[, ...])

Run differential gene program tests by comparing gene program / latent scores between a category and specified comparison categories for all categories in ´selected_cats´ (by default all categories in ´adata.obs[cat_key]´).

save(dir_path[, overwrite, save_adata, ...])

Save model to disk (the Trainer optimizer state is not saved).

train([n_epochs, n_epochs_all_gps, ...])

Train the NicheCompass model.

Methods

NicheCompass.add_active_gp_scores_to_obs()

Add the expression of all active gene programs to ´adata.obs´.

Return type:

None

NicheCompass.compute_gp_gene_importances(selected_gp)

Compute gene importances for the genes of a given gene program. Gene importances are determined by the normalized weights of the RNA decoders.

Parameters:

selected_gp (str) – Name of the gene program for which the gene importances should be retrieved.

Return type:

DataFrame

Returns:

gp_gene_importances_df: Pandas DataFrame containing genes, gene weights, gene importances and an indicator whether the gene belongs to the communication source or target, stored in ´gene_entity´.
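As a rough illustration of what "normalized weights" can mean here, the sketch below divides the absolute value of each gene weight by the sum of absolute weights within one gene program, so that the importances sum to 1. This normalization scheme is an assumption made for the example, and ´normalize_gene_weights´ and the gene names are hypothetical. The actual computation lives in the NicheCompass source.

```python
def normalize_gene_weights(gene_weights):
    """Map {gene: decoder weight} to {gene: importance}.

    Assumed scheme: |weight| divided by the sum of |weight| over the
    gene program, so importances are non-negative and sum to 1.
    """
    total = sum(abs(w) for w in gene_weights.values())
    if total == 0:
        return {gene: 0.0 for gene in gene_weights}
    return {gene: abs(w) / total for gene, w in gene_weights.items()}


# Hypothetical decoder weights for one gene program
weights = {"GeneA": 0.6, "GeneB": -0.3, "GeneC": 0.1}
importances = normalize_gene_weights(weights)
# Genes ordered from most to least important
ranked = sorted(importances, key=importances.get, reverse=True)
```

Under this assumption a strongly negative weight counts as much as a strongly positive one, since only the magnitude enters the importance.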

NicheCompass.compute_gp_peak_importances(selected_gp)

Compute peak importances for the peaks of a given gene program. Peak importances are determined by the normalized weights of the ATAC decoders.

Parameters:

selected_gp (str) – Name of the gene program for which the peak importances should be retrieved.

Return type:

DataFrame

Returns:

gp_peak_importances_df: Pandas DataFrame containing peaks, peak weights, peak importances and an indicator whether the peak belongs to the communication source or target, stored in ´peak_entity´.

NicheCompass.get_active_gps()

Get active gene programs based on the gene expression decoder gene weights of gene programs. Active gene programs are gene programs whose absolute gene weights aggregated over all genes are greater than ´self.active_gp_thresh_ratio_´ times the absolute gene weights aggregation of the gene program with the maximum value across all gene programs.

Parameters:

adata – AnnData object to get the active gene programs for. If ´None´, uses the adata object stored in the model instance.

Return type:

ndarray

Returns:

active_gps: Gene program names of active gene programs (dim: n_active_gps,)
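The activity criterion described for ´get_active_gps()´ and ´active_gp_thresh_ratio´ can be sketched as follows. The sketch assumes that the aggregation is a plain sum of absolute gene weights (the text only says "aggregated"), and ´active_gp_mask´ is a hypothetical helper, not the library's ´get_active_gp_mask()´.

```python
def active_gp_mask(gp_gene_weights, thresh_ratio=0.01):
    """Return {gp_name: is_active} from {gp_name: [gene weights]}.

    A gene program is set to inactive if its aggregated absolute weight is
    smaller than thresh_ratio times the maximum aggregated value.
    The aggregation (a sum) is an assumption of this sketch.
    """
    agg = {gp: sum(abs(w) for w in ws) for gp, ws in gp_gene_weights.items()}
    if thresh_ratio == 0:
        # A ratio of 0 keeps every gene program active
        return {gp: True for gp in agg}
    cutoff = thresh_ratio * max(agg.values())
    return {gp: value >= cutoff for gp, value in agg.items()}


# Hypothetical decoder weights for three gene programs
gp_weights = {
    "GP1": [0.5, -1.5],   # aggregated value 2.0 (the maximum)
    "GP2": [0.01, 0.0],   # aggregated value 0.01, below 0.01 * 2.0
    "GP3": [0.1, -0.2],   # aggregated value about 0.3
}
mask = active_gp_mask(gp_weights, thresh_ratio=0.01)
```

With the default ratio of 0.01, only gene programs whose aggregated weight falls below 1% of the strongest gene program are dropped.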

NicheCompass.get_cat_covariates_embeds()

Get the categorical covariates embeddings.

Return type:

ndarray

Returns:

cat_covariates_embeds:

Categorical covariates embeddings.

NicheCompass.get_gp_data(selected_gps=None)

Get the index of selected gene programs as well as their omics decoder weights.

Return type:

Tuple[ndarray, ndarray, ndarray]

Parameters:

selected_gps:

Names of the selected gene programs for which data should be retrieved.

Returns:

selected_gps_idx:

Index of the selected gene programs (dim: n_selected_gps,)

selected_gps_rna_decoder_weights:

Gene weights of the RNA decoders of the selected gene programs (dim: (2 * n_genes) x n_selected_gps).

selected_gps_atac_decoder_weights:

Peak weights of the ATAC decoders of the selected gene programs (dim: (2 * n_peaks) x n_selected_gps).

NicheCompass.get_gp_summary()

Get summary information of gene programs and return it as a DataFrame.

Return type:

DataFrame

Returns:

gp_summary_df: DataFrame with gene program summary information.

NicheCompass.get_latent_representation(adata=None, adata_atac=None, counts_key='counts', adj_key='spatial_connectivities', cat_covariates_keys=None, only_active_gps=True, return_mu_std=False, node_batch_size=64, dtype=<class 'numpy.float64'>)

Get the latent representation / gene program scores from a trained model.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData object to get the latent representation for. If ´None´, uses the adata object stored in the model instance.

  • counts_key (Optional[str] (default: 'counts')) – Key under which the counts are stored in ´adata.layers´. If ´None´, uses ´adata.X´ as counts.

  • adj_key (str (default: 'spatial_connectivities')) – Key under which the sparse adjacency matrix is stored in ´adata.obsp´.

  • cat_covariates_keys (Optional[List[str]] (default: None)) – Keys under which the categorical covariates are stored in ´adata.obs´.

  • only_active_gps (bool (default: True)) – If ´True´, return only the latent representation of active gps.

  • return_mu_std (bool (default: False)) – If True, return ´mu´ and ´std´ instead of latent features ´z´.

  • node_batch_size (int (default: 64)) – Batch size used during data loading.

  • dtype (type (default: <class 'numpy.float64'>)) – Precision to store the latent representations.

Return type:

Union[ndarray, Tuple[ndarray, ndarray]]

Returns:

z:

Latent space features (dim: n_obs x n_active_gps or n_obs x n_gps).

mu:

Expected values of the latent posterior (dim: n_obs x n_active_gps or n_obs x n_gps).

std:

Standard deviations of the latent posterior (dim: n_obs x n_active_gps or n_obs x n_gps).

NicheCompass.get_neighbor_importances(node_batch_size=None)

Get the aggregation weights of the node label aggregator. The aggregation weights indicate how much importance each node / observation has attributed to its neighboring nodes / observations for the omics reconstruction tasks. If ´one-hop-attention´ is used as node label method, the mean over all attention heads is used as aggregation weights.

Parameters:

node_batch_size (Optional[int] (default: None)) – Batch size that is used by the node-level dataloader. If ´None´, uses the node batch size used during model training.

Return type:

csr_matrix

Returns:

agg_weights: A sparse scipy matrix containing the aggregation weights of the node label aggregator (dim: n_obs x n_obs). Row-wise entries will be neighbor importances for each observation. The matrix is not symmetric.

NicheCompass.get_omics_decoder_outputs(adata=None, adata_atac=None, only_active_gps=True, node_batch_size=64)

Get the omics decoder outputs.

Parameters:
  • adata (Optional[AnnData] (default: None)) – AnnData object to get the omics decoder outputs for. If ´None´, uses the adata object stored in the model instance.

  • counts_key – Key under which the counts are stored in ´adata.layers´. If ´None´, uses ´adata.X´ as counts.

  • adj_key – Key under which the sparse adjacency matrix is stored in ´adata.obsp´.

  • cat_covariates_keys – Keys under which the categorical covariates are stored in ´adata.obs´.

  • only_active_gps (bool (default: True)) – If ´True´, return only the latent representation of active gps.

Return type:

Union[ndarray, Tuple[ndarray, ndarray]]

Returns:

output: A dictionary containing the omics decoder outputs.

NicheCompass.get_recon_edge_probs(node_batch_size=2048, device=None, edge_thresh=None, n_neighbors=None, return_edge_probs=False)

Get the reconstructed adjacency matrix (or the edge probability matrix if ´return_edge_probs == True´) from a trained NicheCompass model.

Parameters:
  • node_batch_size (int (default: 2048)) – Batch size for batched decoder forward pass to alleviate memory consumption. Only relevant if ´return_edge_probs == False´.

  • device (Optional[str] (default: None)) – Device where the computation will be executed.

  • edge_thresh (Optional[float] (default: None)) – Probability threshold above or equal to which edge probabilities lead to a reconstructed edge. If ´None´, ´n_neighbors´ will be used to compute an independent edge threshold for each observation.

  • n_neighbors (Optional[int] (default: None)) – Number of neighbors used to compute an independent edge threshold for each observation (before the adjacency matrix is made symmetric). Only applies if ´edge_thresh is None´. On some occasions, when multiple edges have the same probability, the number of reconstructed edges can deviate slightly from ´n_neighbors´. If ´None´, the number of neighbors in the original (symmetric) spatial graph stored in ´adata.obsp[self.adj_key_]´ is used to compute an independent edge threshold for each observation (in this case the adjacency matrix is not made symmetric).

  • return_edge_probs (bool (default: False)) – If ´True´, return a matrix of edge probabilities instead of the reconstructed adjacency matrix. This will require a lot of memory as a dense tensor will be returned instead of a sparse matrix.

Return type:

Union[csr_matrix, Tensor]

Returns:

adj_recon:

Sparse scipy matrix containing reconstructed edges (dim: n_nodes x n_nodes).

adj_recon_probs:

Tensor containing edge probabilities (dim: n_nodes x n_nodes).
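The two thresholding modes described for ´edge_thresh´ and ´n_neighbors´ can be sketched on a small dense probability matrix. This is a simplified illustration: ´reconstruct_edges´ is a hypothetical helper, self-edges are excluded by assumption, the symmetrization step is omitted, and the fallback to the neighbor counts of the original spatial graph (when both arguments are ´None´) is not covered.

```python
def reconstruct_edges(edge_probs, edge_thresh=None, n_neighbors=None):
    """Turn a dense edge probability matrix into a 0/1 adjacency matrix.

    If edge_thresh is given, one global threshold is used. Otherwise each
    row gets its own threshold: the n_neighbors-th largest probability.
    """
    if edge_thresh is None and n_neighbors is None:
        raise ValueError("this sketch needs edge_thresh or n_neighbors")
    n = len(edge_probs)
    adj = [[0] * n for _ in range(n)]
    for i, row in enumerate(edge_probs):
        if edge_thresh is not None:
            thresh = edge_thresh
        else:
            # Independent threshold per observation (self-edge excluded)
            others = sorted((p for j, p in enumerate(row) if j != i), reverse=True)
            thresh = others[n_neighbors - 1]
        for j, p in enumerate(row):
            # Probabilities above or equal to the threshold become edges
            if j != i and p >= thresh:
                adj[i][j] = 1
    return adj


probs = [
    [0.0, 0.9, 0.2, 0.6],
    [0.9, 0.0, 0.4, 0.4],
    [0.2, 0.4, 0.0, 0.8],
    [0.6, 0.4, 0.8, 0.0],
]
adj_knn = reconstruct_edges(probs, n_neighbors=1)
adj_ties = reconstruct_edges(probs, n_neighbors=2)
adj_global = reconstruct_edges(probs, edge_thresh=0.5)
```

In ´adj_ties´, the second row ends up with three edges instead of two because two candidates share the probability 0.4, which mirrors the note above that tied probabilities can make the edge count deviate from ´n_neighbors´.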

classmethod NicheCompass.load(dir_path, adata=None, adata_atac=None, adata_file_name='adata.h5ad', adata_atac_file_name=None, use_cuda=False, n_addon_gps=0, gp_names_key=None, genes_idx_key=None, unfreeze_all_weights=False, unfreeze_addon_gp_weights=False, unfreeze_cat_covariates_embedder_weights=False)

Instantiate a model from saved output. Can be used for transfer learning scenarios and to learn de-novo gene programs by adding add-on gene programs and freezing non add-on weights.

Parameters:
  • dir_path (str) – Path to saved outputs.

  • adata (Optional[AnnData] (default: None)) – AnnData organized in the same way as data used to train the model. If ´None´, will check for and load adata saved with the model.

  • adata_atac (Optional[AnnData] (default: None)) – ATAC AnnData organized in the same way as data used to train the model. If ´None´ and ´adata_atac_file_name´ is not ´None´, will check for and load adata_atac saved with the model.

  • adata_file_name (str (default: 'adata.h5ad')) – File name of the AnnData object to be loaded.

  • adata_atac_file_name (Optional[str] (default: None)) – File name of the ATAC AnnData object to be loaded.

  • use_cuda (bool (default: False)) – If True, load model on GPU.

  • n_addon_gps (int (default: 0)) – Number of (new) add-on gene programs to be added to the model’s architecture.

  • gp_names_key (Optional[str] (default: None)) – Key under which the gene program names are stored in ´adata.uns´.

  • unfreeze_all_weights (bool (default: False)) – If True, unfreeze all weights.

  • unfreeze_addon_gp_weights (bool (default: False)) – If True, unfreeze addon gp weights.

  • unfreeze_cat_covariates_embedder_weights (bool (default: False)) – If True, unfreeze categorical covariates embedder weights.

Return type:

Module

Returns:

model: Model with loaded state dictionaries and, if specified, frozen non add-on weights.

NicheCompass.run_differential_gp_tests(cat_key, selected_cats=None, comparison_cats='rest', selected_gps=None, n_sample=10000, log_bayes_factor_thresh=2.3, key_added='nichecompass_differential_gp_test_results', seed=0, adata=None)

Run differential gene program tests by comparing gene program / latent scores between a category and specified comparison categories for all categories in ´selected_cats´ (by default all categories in ´adata.obs[cat_key]´). Enriched category gene programs are determined through the log Bayes Factor between the hypothesis h0 that the (normalized) gene program / latent scores of observations of the category under consideration (z0) are higher than the (normalized) gene program / latent scores of observations of the comparison categories (z1) versus the alternative hypothesis h1 that the (normalized) gene program / latent scores of observations of the comparison categories (z1) are higher or equal to the (normalized) gene program / latent scores of observations of the category under consideration (z0). The results of the differential tests including the log Bayes Factors for enriched category gene programs are stored in a pandas DataFrame under ´adata.uns[key_added]´. The DataFrame also stores p_h0, the probability that z0 > z1 and p_h1, the probability that z1 >= z0. The rows are ordered by the log Bayes Factor. In addition, the (normalized) gene program / latent scores of enriched gene programs across any of the categories are stored in ´adata.obs´.

Parts of the implementation are adapted from Lotfollahi, M. et al. Biologically informed deep learning to query gene programs in single-cell atlases. Nat. Cell Biol. 25, 337–350 (2023); https://github.com/theislab/scarches/blob/master/scarches/models/expimap/expimap_model.py#L429 (24.11.2022).

Parameters:
  • cat_key (str) – Key under which the categories and comparison categories are stored in ´adata.obs´.

  • selected_cats (Union[str, list, None] (default: None)) – List of category labels for which differential tests will be run. If ´None´, uses all category labels from ´adata.obs[cat_key]´.

  • comparison_cats (Union[str, list] (default: 'rest')) – Categories used as comparison group. If ´rest´, all categories other than the category under consideration are used as comparison group.

  • selected_gps (Union[str, list, None] (default: None)) – List of gene program names for which differential tests will be run. If ´None´, uses all active gene programs.

  • n_sample (int (default: 10000)) – Number of observations to be drawn from the category and comparison categories for the log Bayes Factor computation.

  • log_bayes_factor_thresh (float (default: 2.3)) – Log Bayes Factor threshold. Category gene programs whose absolute log Bayes Factor exceeds this threshold are considered enriched.

  • key_added (str (default: 'nichecompass_differential_gp_test_results')) – Key under which the test results pandas DataFrame is stored in ´adata.uns´.

  • seed (int (default: 0)) – Random seed for reproducible sampling.

  • adata (Optional[AnnData] (default: None)) – AnnData object to be used. If ´None´, uses the adata object stored in the model instance.

Return type:

list

Returns:

enriched_gps: Names of enriched gene programs across all categories (duplicate gene programs that appear for multiple categories are only considered once).
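The log Bayes Factor described for ´run_differential_gp_tests()´ can be illustrated for a single gene program with a small sampling sketch. It assumes that scores are sampled with replacement and compared pairwise, that the natural logarithm is used, and that a small ´eps´ guards against taking the log of zero. These details, and the helper name ´log_bayes_factor´, are assumptions of the example rather than a description of the actual implementation.

```python
import math
import random


def log_bayes_factor(z0, z1, n_sample=10000, seed=0, eps=1e-8):
    """Estimate p_h0 = P(z0 > z1), p_h1 = P(z1 >= z0) and their log ratio.

    z0: scores of the category under consideration.
    z1: scores of the comparison categories.
    """
    rng = random.Random(seed)
    s0 = rng.choices(z0, k=n_sample)  # sampling with replacement (assumed)
    s1 = rng.choices(z1, k=n_sample)
    p_h0 = sum(a > b for a, b in zip(s0, s1)) / n_sample
    p_h1 = 1.0 - p_h0
    log_bf = math.log(p_h0 + eps) - math.log(p_h1 + eps)
    return p_h0, p_h1, log_bf


# Hypothetical gene program scores: the category clearly scores higher
z0 = [2.0, 2.5, 3.0]
z1 = [0.0, 0.5, 1.0]
p_h0, p_h1, log_bf = log_bayes_factor(z0, z1)
enriched = abs(log_bf) > 2.3
```

If the natural logarithm is used, the default threshold of 2.3 corresponds to odds of roughly 10 to 1, since e^2.3 is about 9.97.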

NicheCompass.save(dir_path, overwrite=False, save_adata=False, adata_file_name='adata.h5ad', save_adata_atac=False, adata_atac_file_name='adata_atac.h5ad', **anndata_write_kwargs)

Save model to disk (the Trainer optimizer state is not saved).

Parameters:
  • dir_path (str) – Path of the directory where the model will be saved.

  • overwrite (bool (default: False)) – If True, overwrite existing data. If False and a directory already exists at ´dir_path´, an error is raised.

  • save_adata (bool (default: False)) – If True, also saves the AnnData object.

  • adata_file_name (str (default: 'adata.h5ad')) – File name under which the AnnData object will be saved.

  • save_adata_atac (bool (default: False)) – If True, also saves the ATAC AnnData object.

  • adata_atac_file_name (str (default: 'adata_atac.h5ad')) – File name under which the ATAC AnnData object will be saved.

  • anndata_write_kwargs – Kwargs for the AnnData write function.

NicheCompass.train(n_epochs=100, n_epochs_all_gps=25, n_epochs_no_edge_recon=0, n_epochs_no_cat_covariates_contrastive=5, lr=0.001, weight_decay=0.0, lambda_edge_recon=500000.0, lambda_gene_expr_recon=300.0, lambda_chrom_access_recon=100.0, lambda_cat_covariates_contrastive=0.0, contrastive_logits_pos_ratio=0.0, contrastive_logits_neg_ratio=0.0, lambda_group_lasso=0.0, lambda_l1_masked=0.0, l1_targets_categories=['target_gene'], l1_sources_categories=None, lambda_l1_addon=30.0, edge_val_ratio=0.1, node_val_ratio=0.1, edge_batch_size=256, node_batch_size=None, mlflow_experiment_id=None, retrieve_cat_covariates_embeds=False, retrieve_recon_edge_probs=False, retrieve_agg_weights=False, use_cuda_if_available=True, n_sampled_neighbors=-1, latent_dtype=<class 'numpy.float64'>, **trainer_kwargs)

Train the NicheCompass model.

Parameters:
  • n_epochs (int (default: 100)) – Number of epochs.

  • n_epochs_all_gps (int (default: 25)) – Number of epochs during which all gene programs are used for model training. After that only active gene programs are retained.

  • n_epochs_no_edge_recon (int (default: 0)) – Number of epochs during which the edge reconstruction loss is excluded from backpropagation for pretraining using the other loss components.

  • n_epochs_no_cat_covariates_contrastive (int (default: 5)) – Number of epochs during which the categorical covariates contrastive loss is excluded from backpropagation for pretraining using the other loss components.

  • lr (float (default: 0.001)) – Learning rate.

  • weight_decay (float (default: 0.0)) – Weight decay (L2 penalty).

  • lambda_edge_recon (Optional[float] (default: 500000.0)) – Lambda (weighting factor) for the edge reconstruction loss. If ´>0´, this will enforce gene programs to be meaningful for edge reconstruction and, hence, to preserve spatial colocalization information.

  • lambda_gene_expr_recon (float (default: 300.0)) – Lambda (weighting factor) for the gene expression reconstruction loss. If ´>0´, this will enforce interpretable gene programs that can be combined in a linear way to reconstruct gene expression.

  • lambda_chrom_access_recon (float (default: 100.0)) – Lambda (weighting factor) for the chromatin accessibility reconstruction loss. If ´>0´, this will enforce interpretable gene programs that can be combined in a linear way to reconstruct chromatin accessibility.

  • lambda_cat_covariates_contrastive (float (default: 0.0)) – Lambda (weighting factor) for the categorical covariates contrastive loss. If ´>0´, this pushes observations from different categorical covariates categories whose latent representations are already very similar to become more similar, and those whose latent representations differ to become more different.

  • contrastive_logits_pos_ratio (float (default: 0.0)) – Ratio for determining the logits threshold of positive contrastive examples of node pairs from different categorical covariates categories. The top (´contrastive_logits_pos_ratio´ * 100)% logits of node pairs from different categorical covariates categories serve as positive labels for the contrastive loss.

  • contrastive_logits_neg_ratio (float (default: 0.0)) – Ratio for determining the logits threshold of negative contrastive examples of node pairs from different categorical covariates categories. The bottom (´contrastive_logits_neg_ratio´ * 100)% logits of node pairs from different categorical covariates categories serve as negative labels for the contrastive loss.

  • lambda_group_lasso (float (default: 0.0)) – Lambda (weighting factor) for the group lasso regularization loss of gene programs. If ´>0´, this will enforce sparsity of gene programs.

  • lambda_l1_masked (float (default: 0.0)) – Lambda (weighting factor) for the L1 regularization loss of genes in masked gene programs. If ´>0´, this will enforce sparsity of genes in masked gene programs.

  • l1_targets_categories (Optional[list] (default: ['target_gene'])) – Gene program mask targets categories for which l1 regularization loss will be applied.

  • l1_sources_categories (Optional[list] (default: None)) – Gene program mask sources categories for which l1 regularization loss will be applied.

  • lambda_l1_addon (float (default: 30.0)) – Lambda (weighting factor) for the L1 regularization loss of genes in addon gene programs. If ´>0´, this will enforce sparsity of genes in addon gene programs.

  • edge_val_ratio (float (default: 0.1)) – Fraction of the data that is used as validation set on edge-level. The rest of the data will be used as training set on edge-level.

  • node_val_ratio (float (default: 0.1)) – Fraction of the data that is used as validation set on node-level. The rest of the data will be used as training set on node-level.

  • edge_batch_size (int (default: 256)) – Batch size for the edge-level dataloaders.

  • node_batch_size (Optional[int] (default: None)) – Batch size for the node-level dataloaders. If ´None´, is automatically determined based on ´edge_batch_size´.

  • mlflow_experiment_id (Optional[str] (default: None)) – ID of the MLflow experiment used for tracking training parameters and metrics.

  • retrieve_cat_covariates_embeds (bool (default: False)) – If ´True´, retrieve the categorical covariates embeddings after model training is finished if multiple categorical covariates categories are present.

  • retrieve_recon_edge_probs (bool (default: False)) – If ´True´, retrieve the reconstructed edge probabilities after model training is finished.

  • retrieve_agg_weights (bool (default: False)) – If ´True´, retrieve the node label aggregation weights after model training is finished.

  • use_cuda_if_available (bool (default: True)) – If True, use cuda if available.

  • n_sampled_neighbors (int (default: -1)) – Number of neighbors that are sampled during model training from the spatial neighborhood graph.

  • latent_dtype (type (default: <class 'numpy.float64'>)) – Data type for storing the latent representations. Set to ´np.float16´ for very large datasets (more than 1 million observations).

  • trainer_kwargs – Kwargs for the model Trainer.
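The interplay of the ´lambda_*´ weighting factors with ´n_epochs_no_edge_recon´ and ´n_epochs_no_cat_covariates_contrastive´ can be sketched as a weighted sum with epoch-dependent exclusions. The sketch assumes that the total loss is a plain weighted sum of the components and that epochs are counted from 0. ´weighted_loss´ and the component names are hypothetical and only mirror the parameter names above.

```python
def weighted_loss(
    epoch,
    losses,
    lambdas,
    n_epochs_no_edge_recon=0,
    n_epochs_no_cat_covariates_contrastive=5,
):
    """Combine loss components into one value for a given (0-based) epoch.

    During the first n_epochs_no_edge_recon epochs the edge reconstruction
    loss is left out, and likewise for the contrastive loss.
    """
    total = 0.0
    for name, value in losses.items():
        if name == "edge_recon" and epoch < n_epochs_no_edge_recon:
            continue
        if (
            name == "cat_covariates_contrastive"
            and epoch < n_epochs_no_cat_covariates_contrastive
        ):
            continue
        total += lambdas[name] * value
    return total


# Weighting factors (the contrastive weight is set to 1.0 for illustration)
lambdas = {
    "edge_recon": 500000.0,
    "gene_expr_recon": 300.0,
    "cat_covariates_contrastive": 1.0,
}
# Hypothetical raw loss values for one batch
losses = {"edge_recon": 0.5, "gene_expr_recon": 2.0, "cat_covariates_contrastive": 4.0}

loss_pretrain = weighted_loss(0, losses, lambdas, n_epochs_no_edge_recon=2)
loss_mid = weighted_loss(3, losses, lambdas, n_epochs_no_edge_recon=2)
loss_full = weighted_loss(5, losses, lambdas, n_epochs_no_edge_recon=2)
```

In this example only the gene expression term contributes during the first two epochs, the edge reconstruction term joins from epoch 2, and the contrastive term joins from epoch 5.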