nichecompass.modules.VGPGAE

class nichecompass.modules.VGPGAE(n_input, n_fc_layers_encoder, n_layers_encoder, n_hidden_encoder, n_prior_gp, n_addon_gp, cat_covariates_embeds_nums, n_output_genes, target_rna_decoder_mask, source_rna_decoder_mask, features_idx_dict, features_scale_factors, n_output_peaks=0, target_atac_decoder_mask=None, source_atac_decoder_mask=None, gene_peaks_mask=None, cat_covariates_cats=[], cat_covariates_no_edges=[], conv_layer_encoder='gcnconv', encoder_n_attention_heads=4, encoder_use_bn=False, dropout_rate_encoder=0.0, dropout_rate_graph_decoder=0.0, include_edge_recon_loss=True, include_gene_expr_recon_loss=True, include_chrom_access_recon_loss=True, include_cat_covariates_contrastive_loss=True, rna_recon_loss='nb', atac_recon_loss='nb', node_label_method='one-hop-norm', active_gp_thresh_ratio=0.03, active_gp_type='separate', log_variational=True, cat_covariates_embeds_injection=['gene_expr_decoder', 'chrom_access_decoder'], use_fc_decoder=False, fc_decoder_n_layers=2, include_edge_kl_loss=True)

Variational Gene Program Graph Autoencoder class.

Parameters:
  • n_input (int) – Number of nodes in the input layer.

  • n_fc_layers_encoder (int) – Number of fully connected layers in the encoder.

  • n_layers_encoder (int) – Number of message passing layers in the encoder.

  • n_hidden_encoder (int) – Number of nodes in the encoder hidden layer.

  • n_prior_gp (int) – Number of prior nodes in the latent space (gene programs from the gene program masks).

  • n_addon_gp (int) – Number of add-on nodes in the latent space (de-novo gene programs).

  • cat_covariates_embeds_nums (List[int]) – List with the number of embedding nodes for each categorical covariate.

  • n_output_genes (int) – Number of output genes for the rna decoders.

  • target_rna_decoder_mask (Tensor) – Gene program mask for the target rna decoder.

  • source_rna_decoder_mask (Tensor) – Gene program mask for the source rna decoder.

  • features_idx_dict (dict) – Dictionary containing the indices of the omics features that are masked and of those that are unmasked.

  • n_output_peaks (int (default: 0)) – Number of output peaks for the atac decoders.

  • target_atac_decoder_mask (Optional[Tensor] (default: None)) – Gene program mask for the target atac decoder.

  • source_atac_decoder_mask (Optional[Tensor] (default: None)) – Gene program mask for the source atac decoder.

  • gene_peaks_mask (Optional[Tensor] (default: None)) – A mask to map from genes to peaks, used to turn off peaks in the atac decoders if the corresponding genes have been turned off in the rna decoders by gene regularization.

  • cat_covariates_cats (List[List] (default: [])) – List of category lists for each categorical covariate for the categorical covariates embeddings.

  • cat_covariates_no_edges (List[bool] (default: [])) – List of booleans that indicate whether there can be edges between different categories of the categorical covariates. If this is ´True´ for a specific categorical covariate, this covariate will be excluded from the edge reconstruction loss.

  • conv_layer_encoder (Literal['gcnconv', 'gatv2conv'] (default: 'gcnconv')) – Convolutional layer used in the encoder.

  • encoder_n_attention_heads (int (default: 4)) – Only relevant if ´conv_layer_encoder == gatv2conv´. Number of attention heads used.

  • encoder_use_bn (bool (default: False)) – If ´True´, uses a batch normalization layer at the end of the encoder to normalize ´mu´.

  • dropout_rate_encoder (float (default: 0.0)) – Probability that nodes will be dropped in the encoder during training.

  • dropout_rate_graph_decoder (float (default: 0.0)) – Probability that nodes will be dropped in the graph decoder during training.

  • include_edge_recon_loss (bool (default: True)) – If True, includes the edge reconstruction loss in the loss optimization.

  • include_gene_expr_recon_loss (bool (default: True)) – If True, includes the gene expression reconstruction loss in the loss optimization.

  • include_chrom_access_recon_loss (bool (default: True)) – If True, includes the chromatin accessibility reconstruction loss in the loss optimization.

  • include_cat_covariates_contrastive_loss (bool (default: True)) – If True, includes the categorical covariates contrastive loss in the loss optimization.

  • rna_recon_loss (Literal['nb'] (default: 'nb')) – The loss used for gene expression reconstruction. If ´nb´, uses a negative binomial loss.

  • node_label_method (Literal['one-hop-norm', 'one-hop-sum', 'one-hop-attention'] (default: 'one-hop-norm')) – Node label method that will be used for omics reconstruction. If ´one-hop-sum´, uses a concatenation of the node’s input features with the sum of the input features of all nodes in the node’s one-hop neighborhood. If ´one-hop-norm´, uses a concatenation of the node’s input features with the node’s one-hop neighbors’ input features normalized as per Kipf, T. N. & Welling, M. Semi-Supervised Classification with Graph Convolutional Networks. arXiv [cs.LG] (2016). If ´one-hop-attention´, uses a concatenation of the node’s input features with the node’s one-hop neighbors’ input features weighted by an attention mechanism.

  • active_gp_thresh_ratio (float (default: 0.03)) – Ratio that determines which gene programs are considered active and are used for edge reconstruction and omics reconstruction. All inactive gene programs will be dropped out. Aggregations of the absolute values of the gene weights of the gene expression decoder per gene program are calculated. The maximum value, i.e. the value of the gene program with the highest aggregated value will be used as a benchmark and all gene programs whose aggregated value is smaller than ´active_gp_thresh_ratio´ times this maximum value will be set to inactive. If ´==0´, all gene programs will be considered active. More information can be found in ´self.get_active_gp_mask()´.

  • active_gp_type (Literal['mixed', 'separate'] (default: 'separate')) – Type to determine active gene programs. Can be ´mixed´, in which case active gene programs are determined jointly across prior and add-on gene programs, or ´separate´, in which case they are determined separately for prior and add-on gene programs.

  • log_variational (bool (default: True)) – If ´True´, transforms x by log(x+1) prior to encoding for numerical stability (not normalization).

  • cat_covariates_embeds_injection (Optional[List[Literal['encoder', 'gene_expr_decoder', 'chrom_access_decoder']]] (default: ['gene_expr_decoder', 'chrom_access_decoder'])) – List of VGPGAE modules in which the categorical covariates embeddings are injected.

  • use_fc_decoder (bool (default: False)) – If ´True´, uses a fully connected decoder instead of a masked decoder. For ablation purposes only.

  • fc_decoder_n_layers (int (default: 2)) – Number of layers to use if ´use_fc_decoder == True´.
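The ´one-hop-sum´ and ´one-hop-norm´ options of ´node_label_method´ can be illustrated on a toy graph. The following is a minimal plain-Python sketch, not the module's implementation (which operates on PyTorch tensors). The helper names are hypothetical, and the symmetric degree normalization without self-loops is an assumption about what "normalized as per Kipf & Welling" means here.

```python
import math


def one_hop_sum_labels(x, edges):
    """Concatenate each node's features with the sum of its neighbors' features."""
    n, d = len(x), len(x[0])
    agg = [[0.0] * d for _ in range(n)]
    for i, j in edges:  # undirected edges: aggregate in both directions
        for k in range(d):
            agg[i][k] += x[j][k]
            agg[j][k] += x[i][k]
    return [list(x[i]) + agg[i] for i in range(n)]


def one_hop_norm_labels(x, edges):
    """Same as above, but each neighbor is weighted by 1 / sqrt(deg_i * deg_j).

    Assumption: symmetric normalization without self-loops.
    """
    n, d = len(x), len(x[0])
    deg = [0] * n
    for i, j in edges:
        deg[i] += 1
        deg[j] += 1
    agg = [[0.0] * d for _ in range(n)]
    for i, j in edges:
        w = 1.0 / math.sqrt(deg[i] * deg[j])
        for k in range(d):
            agg[i][k] += w * x[j][k]
            agg[j][k] += w * x[i][k]
    return [list(x[i]) + agg[i] for i in range(n)]


# Toy graph: three nodes on a path 0 - 1 - 2, two features per node.
x = [[1.0, 0.0], [0.0, 2.0], [3.0, 1.0]]
edges = [(0, 1), (1, 2)]
sum_labels = one_hop_sum_labels(x, edges)
norm_labels = one_hop_norm_labels(x, edges)
```

In both cases the node label has twice as many entries as the input features: the node's own features followed by the aggregated neighborhood features.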

Methods table

forward(data_batch, decoder[, ...])

Forward pass of the VGPGAE module.

get_active_gp_mask(self[, ...])

get_gp_weights(self[, only_masked_features, ...])

get_latent_representation(node_batch[, ...])

Encode input features ´x´ and ´edge_index´ into the latent distribution parameters and return either the distribution parameters themselves or latent features ´z´.

get_omics_decoder_outputs(node_batch[, ...])

Decode latent features ´z´ to return the parameters of the distribution used for omics reconstruction.

load_and_expand_state_dict(model_state_dict)

Load model state dictionary into model and expand it to account for architectural changes through e.g. add-on nodes.

log_module_hyperparams_to_mlflow([excluded_attr])

Log module hyperparameters to MLflow.

loss(edge_model_output, node_model_output, ...)

Calculate the optimization loss for backpropagation as well as the global loss, which also contains components omitted from optimization (not backpropagated) and is used for early stopping evaluation.

reparameterize(mu, logstd)

Use reparameterization trick for latent space normal distribution.

Methods

VGPGAE.forward(data_batch, decoder, use_only_active_gps=False, return_agg_weights=False, update_atac_dynamic_decoder_mask=False)

Forward pass of the VGPGAE module.

Parameters:
  • data_batch (Data) – PyG Data object containing either an edge-level batch if ´decoder == graph´ or a node-level batch if ´decoder == omics´.

  • decoder (Literal['graph', 'omics']) – Decoder to use for the forward pass. Either ´graph´ for edge reconstruction or ´omics´ for gene expression and (if specified) chromatin accessibility reconstruction.

  • use_only_active_gps (bool (default: False)) – If ´True´, use only active gene programs as input to decoder.

  • return_agg_weights (bool (default: False)) – If ´True´, also return the aggregation weights of the node label aggregator.

  • update_atac_dynamic_decoder_mask (bool (default: False)) – If ´True´, turn off the mapped peaks for genes that have been turned off in a gene program (set peak gp weights to 0).

Return type:

dict

Returns:

output: Dictionary containing reconstructed edge logits if ´decoder == graph´ or the parameters of the omics feature distributions if ´decoder == omics´, as well as ´mu´ and ´logstd´ from the latent space distribution.

VGPGAE.get_active_gp_mask(abs_gp_weights_agg_mode='sum+nzmeans', return_gp_weights=False, normalize_gp_weights_with_features_scale_factors=False)

Get a mask of active gene programs based on the rna decoder gene weights of gene programs. Active gene programs are gene programs whose absolute gene weights aggregated over all genes are greater than ´self.active_gp_thresh_ratio_´ times the absolute gene weights aggregation of the gene program with the maximum value across all gene programs. Depending on ´abs_gp_weights_agg_mode´, the aggregation will be either a sum of absolute gene weights (prioritizes gene programs that reconstruct many genes) or a mean of non-zero absolute gene weights (normalizes for the number of genes that a gene program reconstructs) or a combination of the two.

Parameters:
  • abs_gp_weights_agg_mode (Literal['sum', 'nzmeans', 'sum+nzmeans', 'nzmedians', 'sum+nzmedians'] (default: 'sum+nzmeans')) – If ´sum´, uses sums of absolute gp weights for aggregation and active gp determination. If ´nzmeans´, uses means of non-zero absolute gp weights for aggregation and active gp determination. If ´sum+nzmeans´, uses a combination of sums and means of non-zero absolute gp weights for aggregation and active gp determination.

  • return_gp_weights (bool (default: False)) – If ´True´, in addition return the rna decoder gene weights of the active gene programs.

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

active_gp_mask:

Boolean tensor of gene programs which contains True for active gene programs and False for inactive gene programs.

active_gp_weights:

Tensor containing the rna decoder gene weights of active gene programs.
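The thresholding rule described above can be sketched in plain Python. This is an illustration under stated assumptions, not the method's implementation: the function names are hypothetical, weights are given as one list of gene weights per gene program, and only the ´sum´ and ´nzmeans´ modes are shown because the source does not specify how the combined modes weigh their two parts.

```python
def aggregate_abs_weights(gp_weights, mode="sum"):
    """Aggregate absolute gene weights per gene program.

    gp_weights: list with one list of gene weights per gene program.
    """
    aggs = []
    for weights in gp_weights:
        abs_w = [abs(w) for w in weights]
        if mode == "sum":
            aggs.append(sum(abs_w))
        elif mode == "nzmeans":
            nz = [w for w in abs_w if w > 0]
            aggs.append(sum(nz) / len(nz) if nz else 0.0)
        else:
            raise ValueError(f"Unsupported mode in this sketch: {mode}")
    return aggs


def active_gp_mask(gp_weights, thresh_ratio, mode="sum"):
    """Mark a gene program inactive if its aggregate is below ratio * max aggregate."""
    aggs = aggregate_abs_weights(gp_weights, mode)
    max_agg = max(aggs)
    return [agg >= thresh_ratio * max_agg for agg in aggs]


# Three gene programs over three genes.
gp_weights = [[0.5, -1.5, 0.0], [0.01, 0.0, 0.0], [0.0, 0.0, 0.0]]
mask_sum = active_gp_mask(gp_weights, thresh_ratio=0.03, mode="sum")
mask_nzmeans = active_gp_mask(gp_weights, thresh_ratio=0.03, mode="nzmeans")
mask_all = active_gp_mask(gp_weights, thresh_ratio=0.0, mode="sum")
```

With a ratio of 0 the threshold is 0, so every gene program passes, which matches the documented behavior of ´active_gp_thresh_ratio´.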

VGPGAE.get_gp_weights(only_masked_features=False, gp_type='all')

Get the gene program weights of the omics feature decoders.

Return type:

List[Tensor]

Returns:

gp_weights_all_modalities:

List of tensors containing the decoder gp weights for each omics modality (dim: (n_prior_gp + n_addon_gp) x n_omics_features)

VGPGAE.get_latent_representation(node_batch, only_active_gps=True, return_mu_std=False)

Encode input features ´x´ and ´edge_index´ into the latent distribution parameters and return either the distribution parameters themselves or latent features ´z´.

Parameters:
  • node_batch (Data) – PyG Data object containing a node-level batch.

  • only_active_gps (bool (default: True)) – If ´True´, return only the latent representation of active gps.

  • return_mu_std (bool (default: False)) – If ´True´, return ´mu´ and ´std´ instead of latent features ´z´.

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

z:

Latent space features (dim: n_obs x n_gps or n_obs x n_active_gps).

mu:

Expected values of the latent posteriors (dim: n_obs x n_gps or n_obs x n_active_gps).

std:

Standard deviations of the latent posteriors (dim: n_obs x n_gps or n_obs x n_active_gps).
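How ´only_active_gps´ changes the output dimensions can be sketched as a column selection. This is a plain-Python illustration with a hypothetical helper and toy values. It assumes that ´std´ is obtained as exp(´logstd´) and that the active gene program mask is applied column-wise.

```python
import math


def select_active_columns(matrix, active_mask):
    """Keep only the columns (gene programs) flagged as active."""
    return [[v for v, keep in zip(row, active_mask) if keep] for row in matrix]


# Toy latent posterior parameters (dim: n_obs x n_gps = 2 x 3).
mu = [[0.1, -0.4, 2.0], [0.3, 0.0, -1.0]]
logstd = [[0.0, -1.0, 0.5], [0.0, -1.0, 0.5]]
active_mask = [True, False, True]

# Assumption: std is the exponential of logstd.
std = [[math.exp(v) for v in row] for row in logstd]

mu_active = select_active_columns(mu, active_mask)    # dim: n_obs x n_active_gps
std_active = select_active_columns(std, active_mask)  # dim: n_obs x n_active_gps
```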

VGPGAE.get_omics_decoder_outputs(node_batch, only_active_gps=True)

Decode latent features ´z´ to return the parameters of the distribution used for omics reconstruction.

Parameters:
  • node_batch (Data) – PyG Data object containing a node-level batch.

  • only_active_gps (bool (default: True)) – If ´True´, return only the latent representation of active gps.

  • cat_covariates_embed – Tensor containing the categorical covariates embedding (dim: n_obs x sum(cat_covariates_embeds_num)).

  • return_mu_std – If ´True´, return ´mu´ and ´std´ instead of latent features ´z´.

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

z:

Latent space features (dim: n_obs x n_gps or n_obs x n_active_gps).

mu:

Expected values of the latent posteriors (dim: n_obs x n_gps or n_obs x n_active_gps).

std:

Standard deviations of the latent posteriors (dim: n_obs x n_gps or n_obs x n_active_gps).

VGPGAE.load_and_expand_state_dict(model_state_dict)

Load model state dictionary into model and expand it to account for architectural changes through e.g. add-on nodes.

Parts of the implementation are adapted from https://github.com/theislab/scarches/blob/master/scarches/models/base/_base.py#L92 (01.10.2022).
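The idea of expanding a stored parameter to fit a larger architecture can be sketched as follows. This is a plain-Python illustration with a hypothetical helper, not the method's code. It assumes that new entries (e.g. for add-on nodes) are appended after the existing ones, so the stored weights fill the leading block and the remaining entries keep their fresh initialization.

```python
def expand_weight(old_weight, new_weight):
    """Copy old_weight into the leading block of a larger, freshly initialized matrix."""
    expanded = [list(row) for row in new_weight]
    for i, row in enumerate(old_weight):
        for j, value in enumerate(row):
            expanded[i][j] = value
    return expanded


# Stored weights for 2 latent nodes; the new model has 3 (one add-on node).
old_weight = [[0.5, -0.2], [1.0, 0.3]]
fresh_weight = [[0.0, 0.0], [0.0, 0.0], [0.01, 0.02]]  # new initialization
expanded = expand_weight(old_weight, fresh_weight)
```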

VGPGAE.log_module_hyperparams_to_mlflow(excluded_attr=['features_idx_dict_', 'gene_peaks_mask_'])

Log module hyperparameters to MLflow.

Parameters:

excluded_attr (list (default: ['features_idx_dict_', 'gene_peaks_mask_'])) – Attributes that are excluded despite being public because of length restrictions of MLflow.

VGPGAE.loss(edge_model_output, node_model_output, lambda_l1_masked, l1_targets_mask, l1_sources_mask, lambda_l1_addon, lambda_group_lasso, lambda_gene_expr_recon=300.0, lambda_chrom_access_recon=100.0, lambda_edge_recon=500000.0, lambda_cat_covariates_contrastive=100000.0, contrastive_logits_pos_ratio=0.125, contrastive_logits_neg_ratio=0.0, edge_recon_active=True, cat_covariates_contrastive_active=True)

Calculate the optimization loss for backpropagation as well as the global loss, which also contains components omitted from optimization (not backpropagated) and is used for early stopping evaluation.

Parameters:
  • edge_model_output (dict) – Output of the edge-level forward pass for edge reconstruction.

  • node_model_output (dict) – Output of the node-level forward pass for omics reconstruction.

  • lambda_l1_masked (float) – Lambda (weighting factor) for the L1 regularization loss of genes in masked gene programs. If ´>0´, this will enforce sparsity of genes in masked gene programs.

  • l1_targets_mask (Tensor) – Boolean gene program gene mask that is True for all gene program target genes to which the L1 regularization loss should be applied (dim: n_genes, n_gps).

  • l1_sources_mask (Tensor) – Boolean gene program gene mask that is True for all gene program source genes to which the L1 regularization loss should be applied (dim: n_genes, n_gps).

  • lambda_l1_addon (float) – Lambda (weighting factor) for the L1 regularization loss of genes in addon gene programs. If ´>0´, this will enforce sparsity of genes in addon gene programs.

  • lambda_group_lasso (float) – Lambda (weighting factor) for the group lasso regularization loss of masked gene programs. If ´>0´, this will enforce sparsity of masked gene programs.

  • lambda_gene_expr_recon (float (default: 300.0)) – Lambda (weighting factor) for the gene expression reconstruction loss. If ´>0´, this will enforce interpretable gene programs that can be combined in a linear way to reconstruct gene expression.

  • lambda_chrom_access_recon (float (default: 100.0)) – Lambda (weighting factor) for the chromatin accessibility reconstruction loss. If ´>0´, this will enforce interpretable gene programs that can be combined in a linear way to reconstruct chromatin accessibility.

  • lambda_edge_recon (Optional[float] (default: 500000.0)) – Lambda (weighting factor) for the edge reconstruction loss. If ´>0´, this will enforce gene programs to be meaningful for edge reconstruction and, hence, to preserve spatial colocalization information.

  • lambda_cat_covariates_contrastive (Optional[float] (default: 100000.0)) – Lambda (weighting factor) for the categorical covariates contrastive loss. If ´>0´, this will push observations from different categorical covariates categories that have very similar latent representations to become more similar, and observations with different latent representations to become more different.

  • contrastive_logits_pos_ratio (float (default: 0.125)) – Ratio for determining the logits threshold of positive contrastive examples of node pairs from different categorical covariates categories. The top (´contrastive_logits_pos_ratio´ * 100)% logits of node pairs from different categorical covariates categories serve as positive labels for the contrastive loss.

  • contrastive_logits_neg_ratio (float (default: 0.0)) – Ratio for determining the logits threshold of negative contrastive examples of node pairs from different categorical covariates categories. The bottom (´contrastive_logits_neg_ratio´ * 100)% logits of node pairs from different categorical covariates categories serve as negative labels for the contrastive loss.

  • edge_recon_active (bool (default: True)) – If ´True´, includes the edge reconstruction loss in the optimization / backpropagation. Setting this to ´False´ at the beginning of model training allows pretraining using other loss components.

  • cat_covariates_contrastive_active (bool (default: True)) – If ´True´, includes the categorical covariates contrastive loss in the optimization / backpropagation. Setting this to ´False´ at the beginning of model training allows pretraining using other loss components.
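The roles of ´contrastive_logits_pos_ratio´ and ´contrastive_logits_neg_ratio´ can be sketched as a ranking of logits. This is a minimal plain-Python illustration with a hypothetical function name. It assumes the logits passed in already belong to node pairs from different categorical covariates categories and that the number of selected pairs is rounded down.

```python
def contrastive_labels(logits, pos_ratio, neg_ratio):
    """Label the top pos_ratio share of logits 1 and the bottom neg_ratio share 0.

    Pairs in between get None, i.e. they do not contribute a label.
    """
    n = len(logits)
    n_pos = int(n * pos_ratio)
    n_neg = int(n * neg_ratio)
    order = sorted(range(n), key=lambda i: logits[i], reverse=True)
    labels = [None] * n
    for i in order[:n_pos]:
        labels[i] = 1
    if n_neg > 0:  # guard: order[-0:] would select every pair
        for i in order[-n_neg:]:
            labels[i] = 0
    return labels


logits = [2.0, -1.0, 0.5, 3.5, 0.0, -2.5, 1.0, -0.5]
labels_default = contrastive_labels(logits, pos_ratio=0.125, neg_ratio=0.0)
labels_with_neg = contrastive_labels(logits, pos_ratio=0.125, neg_ratio=0.25)
```

With the default ratios (0.125 and 0.0), one in eight pairs becomes a positive example and no pair becomes a negative example.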

Return type:

dict

Returns:

loss_dict: Dictionary containing the loss used for backpropagation (loss_dict["optim_loss"]), which consists of all loss components used for optimization, the global loss (loss_dict["global_loss"]), which contains all loss components irrespective of whether they are used for optimization (needed as metric for early stopping and best model saving), as well as all individual loss components that contribute to the global loss.
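The relationship between the optimization loss and the global loss can be sketched as two sums over the same weighted components. This is a plain-Python illustration, not the method's code. The function name and the component key names are hypothetical (only "optim_loss" and "global_loss" come from the description above), the component values are made up, and storing the components in weighted form is an assumption.

```python
def combine_losses(components, lambdas, active):
    """Weight each loss component and build the optimization and global losses."""
    weighted = {name: lambdas[name] * value for name, value in components.items()}
    loss_dict = dict(weighted)
    # The global loss contains every component, whether optimized or not.
    loss_dict["global_loss"] = sum(weighted.values())
    # The optimization loss only contains components that are currently active.
    loss_dict["optim_loss"] = sum(
        value for name, value in weighted.items() if active.get(name, True)
    )
    return loss_dict


components = {
    "edge_recon_loss": 0.002,
    "gene_expr_recon_loss": 1.5,
    "cat_covariates_contrastive_loss": 0.0001,
}
lambdas = {
    "edge_recon_loss": 500000.0,  # documented default of lambda_edge_recon
    "gene_expr_recon_loss": 300.0,  # documented default of lambda_gene_expr_recon
    "cat_covariates_contrastive_loss": 100000.0,
}
# Pretraining phase: edge reconstruction is excluded from backpropagation.
active = {"edge_recon_loss": False}
loss_dict = combine_losses(components, lambdas, active)
```

In this sketch the edge reconstruction term still counts towards the global loss, so early stopping can track it even while it is excluded from optimization.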

VGPGAE.reparameterize(mu, logstd)

Use reparameterization trick for latent space normal distribution.

Parameters:
  • mu (Tensor) – Expected values of the latent space distribution (dim: n_obs, n_gps).

  • logstd (Tensor) – Log standard deviations of the latent space distribution (dim: n_obs, n_gps).

Return type:

Tensor

Returns:

rep: Reparameterized latent features (dim: n_obs, n_gps).
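The reparameterization trick for a normal distribution draws standard normal noise and shifts and scales it, z = mu + exp(logstd) * eps. The following is a minimal plain-Python sketch of that formula on nested lists. The module itself works on tensors, and the explicit ´rng´ argument is an addition for reproducibility in this example.

```python
import math
import random


def reparameterize(mu, logstd, rng):
    """Sample z = mu + exp(logstd) * eps with eps ~ N(0, 1), element-wise."""
    return [
        [m + math.exp(s) * rng.gauss(0.0, 1.0) for m, s in zip(mu_row, logstd_row)]
        for mu_row, logstd_row in zip(mu, logstd)
    ]


rng = random.Random(0)
mu = [[0.1, -0.4], [0.3, 2.0]]  # dim: n_obs x n_gps
logstd = [[0.0, -1.0], [0.5, 0.0]]
rep = reparameterize(mu, logstd, rng)

# With a very small standard deviation the sample collapses onto mu.
narrow_logstd = [[-50.0, -50.0], [-50.0, -50.0]]
rep_narrow = reparameterize(mu, narrow_logstd, rng)
```

Because the randomness enters only through ´eps´, the sample remains a differentiable function of ´mu´ and ´logstd´, which is what allows gradients to flow through the sampling step.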