nichecompass.train.Trainer

class nichecompass.train.Trainer(adata, model, adata_atac=None, counts_key='counts', adj_key='spatial_connectivities', cat_covariates_keys=None, gp_targets_mask_key='nichecompass_gp_targets', gp_sources_mask_key='nichecompass_gp_sources', edge_val_ratio=0.1, node_val_ratio=0.1, edge_batch_size=512, node_batch_size=None, n_sampled_neighbors=-1, use_early_stopping=True, reload_best_model=True, early_stopping_kwargs=None, use_cuda_if_available=True, seed=0, monitor=True, verbose=False, **kwargs)

Trainer class. Encapsulates all logic for NicheCompass model training preparation and model training.

Parts of the implementation are inspired by https://github.com/theislab/scarches/blob/master/scarches/trainers/trvae/trainer.py#L13 (01.10.2022)

Parameters:
  • adata (AnnData) – AnnData object with counts stored in ´adata.layers[counts_key]´ or ´adata.X´ depending on ´counts_key´ and sparse adjacency matrix stored in ´adata.obsp[adj_key]´.

  • adata_atac (Optional[AnnData] (default: None)) – Additional optional AnnData object with paired spatial ATAC data.

  • model (Module) – A NicheCompass module model instance.

  • counts_key (Optional[str] (default: 'counts')) – Key under which the counts are stored in ´adata.layers´. If ´None´, uses ´adata.X´ as counts.

  • adj_key (str (default: 'spatial_connectivities')) – Key under which the sparse adjacency matrix is stored in ´adata.obsp´.

  • cat_covariates_keys (Optional[List[str]] (default: None)) – Keys under which the categorical covariates are stored in ´adata.obs´.

  • gp_targets_mask_key (str (default: 'nichecompass_gp_targets')) – Key under which the gene program targets mask is stored in ´model.adata.varm´. This mask will only be used if no ´gp_targets_mask´ is passed explicitly to the model.

  • gp_sources_mask_key (str (default: 'nichecompass_gp_sources')) – Key under which the gene program sources mask is stored in ´model.adata.varm´. This mask will only be used if no ´gp_sources_mask´ is passed explicitly to the model.

  • edge_val_ratio (float (default: 0.1)) – Fraction of the data that is used as validation set on edge-level. The rest of the data will be used as training set on edge-level.

  • node_val_ratio (float (default: 0.1)) – Fraction of the data that is used as validation set on node-level. The rest of the data will be used as training set on node-level.

  • edge_batch_size (int (default: 512)) – Batch size for the edge-level dataloaders.

  • node_batch_size (Optional[int] (default: None)) – Batch size for the node-level dataloaders.

  • n_sampled_neighbors (int (default: -1)) – Number of neighbors that are sampled during model training from the spatial neighborhood graph.

  • use_early_stopping (bool (default: True)) – If True, the EarlyStopping class is used to prevent overfitting.

  • reload_best_model (bool (default: True)) – If True, the best state of the model with respect to the early stopping criterion is reloaded at the end of training.

  • early_stopping_kwargs (Optional[dict] (default: None)) – Kwargs for the EarlyStopping class.

  • use_cuda_if_available (bool (default: True)) – If ´True´, use CUDA if available.

  • seed (int (default: 0)) – Random seed to get reproducible results.

  • monitor (bool (default: True)) – If ´True´, the progress of training will be printed after each epoch.

  • verbose (bool (default: False)) – If ´True´, print out detailed training progress of individual losses.
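
The roles of ´edge_val_ratio´, ´node_val_ratio´ and ´seed´ can be illustrated with a minimal sketch. It is not the library's implementation: the helper ´split_train_val´ and the dataset sizes are hypothetical, and the sketch assumes a plain seeded random shuffle followed by a hold-out of the requested fraction.

```python
import random


def split_train_val(n_items, val_ratio, seed=0):
    """Hypothetical helper: shuffle indices and hold out a validation fraction."""
    rng = random.Random(seed)            # seeded generator for reproducibility
    indices = list(range(n_items))
    rng.shuffle(indices)
    n_val = round(n_items * val_ratio)   # number of held-out items
    val = sorted(indices[:n_val])
    train = sorted(indices[n_val:])
    return train, val


# Made-up sizes for illustration: 1000 edges and 200 nodes.
edge_train, edge_val = split_train_val(1000, val_ratio=0.1, seed=0)
node_train, node_val = split_train_val(200, val_ratio=0.1, seed=0)

print(len(edge_train), len(edge_val))
print(len(node_train), len(node_val))
```

With a fixed seed the same items end up in the validation set on every run, which is the property the ´seed´ parameter is documented to provide.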

Methods table

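eval_end()

End evaluation logic of NicheCompass model used after model training.

eval_epoch()

Epoch evaluation logic of NicheCompass model used during training.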

is_early_stopping()

Check whether to apply early stopping, update learning rate and save best model state.

train([n_epochs, n_epochs_all_gps, ...])

Train the NicheCompass model.

Methods

Trainer.eval_end()

End evaluation logic of NicheCompass model used after model training.

Trainer.eval_epoch()

Epoch evaluation logic of NicheCompass model used during training.

Trainer.is_early_stopping()

Check whether to apply early stopping, update learning rate and save best model state.

Return type:

bool

Returns:

stop_training: If True, stop NicheCompass model training.
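
As an illustration of the kind of check such a method performs, here is a minimal patience-based sketch. It is an assumption, not the library's ´EarlyStopping´ class: the name ´SimpleEarlyStopping´, the ´patience´ and ´min_delta´ arguments, and the validation losses are all hypothetical, and learning rate updates are left out.

```python
class SimpleEarlyStopping:
    """Hypothetical patience-based early stopping (not the library's class)."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience          # epochs tolerated without improvement
        self.min_delta = min_delta        # minimum decrease that counts as improvement
        self.best_metric = float("inf")
        self.best_epoch = None
        self.epochs_without_improvement = 0

    def step(self, epoch, val_metric):
        """Record a validation metric; return True if training should stop."""
        if val_metric < self.best_metric - self.min_delta:
            # New best: this is where a trainer would save the model state.
            self.best_metric = val_metric
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Made-up validation losses for illustration.
val_losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.73, 0.69]
stopper = SimpleEarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_epoch, stopper.best_metric)
```

In this sketch the best epoch is remembered separately from the stopping epoch, which mirrors why a ´reload_best_model´ option is useful: the state at the end of training is not necessarily the best one seen.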

Trainer.train(n_epochs=100, n_epochs_all_gps=25, n_epochs_no_edge_recon=0, n_epochs_no_cat_covariates_contrastive=5, lr=0.001, weight_decay=0.0, lambda_edge_recon=500000.0, lambda_cat_covariates_contrastive=0.0, contrastive_logits_pos_ratio=0.125, contrastive_logits_neg_ratio=0.125, lambda_gene_expr_recon=100.0, lambda_chrom_access_recon=10.0, lambda_group_lasso=0.0, lambda_l1_masked=0.0, l1_targets_mask=None, l1_sources_mask=None, lambda_l1_addon=0.0, mlflow_experiment_id=None)

Train the NicheCompass model.

Parameters:
  • n_epochs (int (default: 100)) – Number of epochs.

  • n_epochs_all_gps (int (default: 25)) – Number of epochs during which all gene programs are used for model training. After that only active gene programs are retained.

  • n_epochs_no_edge_recon (int (default: 0)) – Number of epochs without edge reconstruction loss for gene expression decoder pretraining.

  • n_epochs_no_cat_covariates_contrastive (int (default: 5)) – Number of epochs without the categorical covariates contrastive loss.

  • lr (float (default: 0.001)) – Learning rate.

  • weight_decay (float (default: 0.0)) – Weight decay (L2 penalty).

  • lambda_edge_recon (Optional[float] (default: 500000.0)) – Lambda (weighting factor) for the edge reconstruction loss. If ´>0´, this will enforce gene programs to be meaningful for edge reconstruction and, hence, to preserve spatial colocalization information.

  • lambda_cat_covariates_contrastive (Optional[float] (default: 0.0)) – Lambda (weighting factor) for the categorical covariates contrastive loss. If ´>0´, this will push observations from different categorical covariates categories whose latent representations are already very similar to become even more similar, and those whose latent representations differ to become more different.

  • contrastive_logits_pos_ratio (Optional[float] (default: 0.125)) – Ratio for determining the logits threshold of positive contrastive examples of node pairs from different categorical covariates categories. The top (´contrastive_logits_pos_ratio´ * 100)% logits of node pairs from different categorical covariates categories serve as positive labels for the contrastive loss.

  • contrastive_logits_neg_ratio (Optional[float] (default: 0.125)) – Ratio for determining the logits threshold of negative contrastive examples of node pairs from different categorical covariates categories. The bottom (´contrastive_logits_neg_ratio´ * 100)% logits of node pairs from different categorical covariates categories serve as negative labels for the contrastive loss.

  • lambda_gene_expr_recon (float (default: 100.0)) – Lambda (weighting factor) for the gene expression reconstruction loss. If ´>0´, this will enforce interpretable gene programs that can be combined in a linear way to reconstruct gene expression.

  • lambda_chrom_access_recon (float (default: 10.0)) – Lambda (weighting factor) for the chromatin accessibility reconstruction loss. If ´>0´, this will enforce interpretable gene programs that can be combined in a linear way to reconstruct chromatin accessibility.

  • lambda_group_lasso (float (default: 0.0)) – Lambda (weighting factor) for the group lasso regularization loss of gene programs. If ´>0´, this will enforce sparsity of gene programs.

  • lambda_l1_masked (float (default: 0.0)) – Lambda (weighting factor) for the L1 regularization loss of genes in masked gene programs. If ´>0´, this will enforce sparsity of genes in masked gene programs.

  • l1_targets_mask (Optional[Tensor] (default: None)) – Boolean gene program gene mask that is True for all gene program target genes to which the L1 regularization loss should be applied (dim: n_genes, n_gps).

  • l1_sources_mask (Optional[Tensor] (default: None)) – Boolean gene program gene mask that is True for all gene program source genes to which the L1 regularization loss should be applied (dim: n_genes, n_gps).

  • lambda_l1_addon (float (default: 0.0)) – Lambda (weighting factor) for the L1 regularization loss of genes in addon gene programs. If ´>0´, this will enforce sparsity of genes in addon gene programs.

  • mlflow_experiment_id (Optional[str] (default: None)) – ID of the mlflow experiment that will be used for tracking.
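
The interplay of the lambda weights, the ´n_epochs_no_*´ warm-up settings and the contrastive ratios can be sketched as follows. This is a simplified illustration under stated assumptions, not the library's code: the helpers ´effective_lambdas´, ´weighted_total´ and ´contrastive_labels´ are hypothetical, epochs are assumed to be 0-based, the total loss is assumed to be a plain weighted sum, and the loss values and logits are made up.

```python
def effective_lambdas(epoch,
                      n_epochs_no_edge_recon=0,
                      n_epochs_no_cat_covariates_contrastive=5,
                      lambda_edge_recon=500000.0,
                      lambda_cat_covariates_contrastive=0.0,
                      lambda_gene_expr_recon=100.0):
    """Hypothetical helper: loss weights active in a given (0-based) epoch."""
    edge_on = epoch >= n_epochs_no_edge_recon
    contrastive_on = epoch >= n_epochs_no_cat_covariates_contrastive
    return {
        "edge_recon": lambda_edge_recon if edge_on else 0.0,
        "cat_covariates_contrastive":
            lambda_cat_covariates_contrastive if contrastive_on else 0.0,
        "gene_expr_recon": lambda_gene_expr_recon,
    }


def weighted_total(losses, lambdas):
    """Assumed combination rule: sum of lambda-weighted loss terms."""
    return sum(lambdas[name] * value for name, value in losses.items())


def contrastive_labels(logits, pos_ratio=0.125, neg_ratio=0.125):
    """Pick indices of the top pos_ratio and bottom neg_ratio share of logits."""
    n_pos = int(len(logits) * pos_ratio)
    n_neg = int(len(logits) * neg_ratio)
    order = sorted(range(len(logits)), key=lambda i: logits[i])  # ascending
    neg_idx = sorted(order[:n_neg])
    pos_idx = sorted(order[len(order) - n_pos:]) if n_pos else []
    return pos_idx, neg_idx


# Made-up per-term losses for one batch.
losses = {"edge_recon": 0.5, "cat_covariates_contrastive": 0.2,
          "gene_expr_recon": 1.5}

early = effective_lambdas(0, n_epochs_no_edge_recon=2,
                          lambda_cat_covariates_contrastive=10.0)
late = effective_lambdas(5, n_epochs_no_edge_recon=2,
                         lambda_cat_covariates_contrastive=10.0)
total_early = weighted_total(losses, early)   # only gene expression term active
total_late = weighted_total(losses, late)     # all three terms active

# Made-up logits of node pairs from different covariate categories.
pair_logits = [0.3, -1.2, 2.5, 0.0, 1.1, -0.4, 0.9, 0.2]
pos_idx, neg_idx = contrastive_labels(pair_logits)
print(total_early, total_late, pos_idx, neg_idx)
```

The sketch also shows why the default ´lambda_edge_recon´ dominates the other terms once it is switched on: with weights several orders of magnitude apart, the relative scale of the raw loss values matters when tuning these parameters.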