nichecompass.modules.compute_cat_covariates_contrastive_loss

nichecompass.modules.compute_cat_covariates_contrastive_loss(edge_recon_logits, edge_recon_labels, edge_same_cat_covariates_cat=None, contrastive_logits_pos_ratio=0.0, contrastive_logits_neg_ratio=0.0)

Compute the categorical covariates contrastive loss as a weighted binary cross entropy loss with logits. The loss is computed separately for each categorical covariate, and the per-covariate losses are summed. For a given categorical covariate, sampled edges whose nodes belong to different categories and whose edge reconstruction logits are among the top (`contrastive_logits_pos_ratio` * 100)% of those logits are treated as positive examples. Sampled edges whose nodes belong to different categories and whose edge reconstruction logits are among the bottom (`contrastive_logits_neg_ratio` * 100)% of those logits are treated as negative examples.
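The selection rule above can be sketched for a single categorical covariate as follows. This is a minimal pure-Python illustration, not the NicheCompass implementation (which operates on PyTorch tensors); the function name `select_contrastive_examples` is hypothetical, and rounding the example counts up with `math.ceil` is an assumption of this sketch.

```python
import math
from typing import List, Tuple


def select_contrastive_examples(
    logits: List[float],
    same_cat: List[bool],
    pos_ratio: float,
    neg_ratio: float,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return indices of positive and negative
    contrastive examples for ONE categorical covariate."""
    # Only edges whose node pair lies in different categories are candidates
    diff_idx = [i for i, same in enumerate(same_cat) if not same]
    n_diff = len(diff_idx)
    # Number of examples taken from the top / bottom of the logit ranking
    # (rounding up is an assumption of this sketch)
    n_pos = math.ceil(n_diff * pos_ratio)
    n_neg = math.ceil(n_diff * neg_ratio)
    # Rank candidate edges by logit, highest first
    ranked = sorted(diff_idx, key=lambda i: logits[i], reverse=True)
    pos_idx = ranked[:n_pos]           # top logits -> positive examples
    neg_idx = ranked[n_diff - n_neg:]  # bottom logits -> negative examples
    return pos_idx, neg_idx


logits = [2.0, -1.0, 0.5, 3.0, -2.5, 1.0]
same_cat = [True, False, False, False, False, True]
print(select_contrastive_examples(logits, same_cat, 0.25, 0.25))
```

Here edges 1 to 4 join nodes from different categories; with both ratios at 0.25, the highest of their logits (edge 3) becomes a positive example and the lowest (edge 4) a negative one. Edges 0 and 5 join same-category nodes and are never selected. With both ratios at their default of 0.0, no examples are selected. If the two ratios sum to more than 1, the positive and negative sets can overlap in this sketch.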

Parameters:
  • edge_recon_logits (Tensor) – Predicted edge reconstruction logits for both positive and negative sampled edges (dim: 2 * edge_batch_size).

  • edge_recon_labels (Tensor) – Edge ground truth labels for both positive and negative sampled edges (dim: 2 * edge_batch_size).

  • edge_same_cat_covariates_cat (Optional[List[Tensor]] (default: None)) – List of boolean tensors, one per categorical covariate, indicating for each positive and negative sampled edge whether its two nodes share the same category of that covariate (dim of each tensor: 2 * edge_batch_size).

  • contrastive_logits_pos_ratio (float (default: 0.0)) – Ratio that determines the logits threshold for positive contrastive examples among node pairs from different categories. Node pairs from different categories whose logits fall in the top (`contrastive_logits_pos_ratio` * 100)% serve as positive labels for the contrastive loss.

  • contrastive_logits_neg_ratio (float (default: 0.0)) – Ratio that determines the logits threshold for negative contrastive examples among node pairs from different categories. Node pairs from different categories whose logits fall in the bottom (`contrastive_logits_neg_ratio` * 100)% serve as negative labels for the contrastive loss.
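To illustrate the shape of `edge_same_cat_covariates_cat`, the sketch below derives one boolean mask per categorical covariate from per-node category labels and a list of sampled edges. It uses plain Python lists instead of PyTorch tensors, and the names `same_category_masks`, `node_cat_covariates` and `edges` are hypothetical; how NicheCompass actually builds these masks is not described here.

```python
from typing import List, Sequence, Tuple


def same_category_masks(
    node_cat_covariates: Sequence[Sequence[int]],
    edges: Sequence[Tuple[int, int]],
) -> List[List[bool]]:
    """Hypothetical helper: one boolean list per categorical covariate,
    True where both endpoints of an edge share that covariate's category."""
    return [
        [cats[src] == cats[dst] for src, dst in edges]
        for cats in node_cat_covariates
    ]


# Two categorical covariates (e.g. sample and condition) for 4 nodes
node_cat_covariates = [[0, 0, 1, 1], [0, 1, 0, 1]]
# Sampled edges (positive and negative edges would both be included)
edges = [(0, 1), (0, 2), (2, 3), (1, 3)]
print(same_category_masks(node_cat_covariates, edges))
```

Each inner list has one entry per sampled edge, matching the documented length of 2 * edge_batch_size when the edge list holds both positive and negative sampled edges.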

Return type:

Tensor

Returns:

cat_covariates_contrastive_loss: Categorical covariates contrastive binary cross entropy loss (computed from logits for numerical stability during backpropagation and summed over all categorical covariates).
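Computing the loss from logits avoids evaluating log(sigmoid(x)) directly, which overflows or underflows for large |x|. The pure-Python sketch below uses the standard stable form of weighted binary cross entropy with logits (the same formula PyTorch documents for `binary_cross_entropy_with_logits` with `pos_weight`) and then sums per-covariate losses. The helper names are hypothetical, averaging within each covariate is an assumption, and the `pos_weight` value is left to the caller because the weighting scheme NicheCompass uses is not specified in this entry.

```python
import math
from typing import List, Sequence, Tuple


def softplus_neg(x: float) -> float:
    """Numerically stable log(1 + exp(-x))."""
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(
    logits: Sequence[float], labels: Sequence[float], pos_weight: float = 1.0
) -> float:
    """Mean weighted BCE computed directly from logits:
    l = (1 - y) * x + (1 + (pos_weight - 1) * y) * log(1 + exp(-x))."""
    if not logits:
        return 0.0
    total = 0.0
    for x, y in zip(logits, labels):
        total += (1.0 - y) * x + (1.0 + (pos_weight - 1.0) * y) * softplus_neg(x)
    return total / len(logits)


def cat_covariates_loss(
    per_covariate: List[Tuple[List[float], List[float]]], pos_weight: float = 1.0
) -> float:
    """Hypothetical helper: sum the per-covariate losses, where each entry
    holds the logits and contrastive labels of the selected examples."""
    return sum(bce_with_logits(x, y, pos_weight) for x, y in per_covariate)


print(bce_with_logits([1000.0, -1000.0], [1.0, 0.0]))  # no overflow
print(cat_covariates_loss([([0.0], [1.0]), ([0.0], [0.0])]))
```

A naive implementation that first applies a sigmoid would return log(0) for a logit of -1000 with label 1; the stable form returns the finite value 1000 instead.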