nichecompass.modules.compute_edge_recon_loss

nichecompass.modules.compute_edge_recon_loss(edge_recon_logits, edge_recon_labels, edge_incl=None)

Compute the weighted binary cross entropy loss for edge reconstruction from the predicted edge logits and the ground-truth edge labels.
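The description does not say how the loss is weighted. One common way to balance positive against negative sampled edges is to up-weight positive edges by `pos_weight = n_neg / n_pos`. PyTorch's `torch.nn.functional.binary_cross_entropy_with_logits` supports this through its `pos_weight` argument. That weighting scheme is an assumption here, not something stated on this page. The dependency-free sketch below computes such a loss directly from logits and applies an optional boolean mask in the way `edge_incl` is described. The names `weighted_bce_with_logits` and `softplus` are hypothetical and not part of NicheCompass.

```python
import math
from typing import Optional, Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    # Rewrite as max(x, 0) + log1p(exp(-|x|)) so exp() never overflows.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_bce_with_logits(
    logits: Sequence[float],
    labels: Sequence[float],
    incl: Optional[Sequence[bool]] = None,
) -> float:
    """Hypothetical sketch of a class-balanced BCE-with-logits loss.

    ASSUMPTION: positives are up-weighted by pos_weight = n_neg / n_pos,
    mirroring the `pos_weight` argument of PyTorch's
    binary_cross_entropy_with_logits. Not taken from the NicheCompass source.
    """
    # Keep only edges selected by the boolean mask (all edges if incl is None).
    if incl is not None:
        pairs = [(x, y) for x, y, keep in zip(logits, labels, incl) if keep]
    else:
        pairs = list(zip(logits, labels))

    n_pos = sum(1 for _, y in pairs if y == 1.0)
    n_neg = sum(1 for _, y in pairs if y == 0.0)
    if n_pos == 0:
        raise ValueError("at least one positive edge is required")
    pos_weight = n_neg / n_pos

    total = 0.0
    for x, y in pairs:
        # -log(sigmoid(x)) == softplus(-x); -log(1 - sigmoid(x)) == softplus(x)
        total += pos_weight * y * softplus(-x) + (1.0 - y) * softplus(x)
    # Mean over the included edges.
    return total / len(pairs)
```

For example, `weighted_bce_with_logits([0.0, 0.0], [1.0, 0.0])` returns log 2 (about 0.693). One positive and one negative edge give `pos_weight = 1`, and each term equals `softplus(0) = log 2`. In a real model the same computation would run on tensors rather than Python lists.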

Parameters:
  • edge_recon_logits (Tensor) – Predicted edge reconstruction logits for both positive and negative sampled edges (dim: 2 * `edge_batch_size`).

  • edge_recon_labels (Tensor) – Ground-truth edge labels for both positive and negative sampled edges (dim: 2 * `edge_batch_size`).

  • edge_incl (Optional[Tensor] (default: None)) – Boolean mask indicating which edges to include in the edge reconstruction loss (dim: 2 * `edge_batch_size`). If `None`, all edges are included.

Return type:

Tensor

Returns:

edge_recon_loss: Weighted binary cross entropy loss between edge labels and predicted edge probabilities (calculated from logits for numerical stability in backpropagation).
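The reason for working from logits can be shown without any framework. If you convert a logit to a probability first and then take the logarithm, the result breaks down for large-magnitude logits. Rewriting `-log(sigmoid(x))` as `softplus(-x)` keeps it finite and precise. The two helper functions below are hypothetical and only demonstrate the difference on plain Python floats. PyTorch's `BCEWithLogitsLoss` documentation gives the same motivation: fusing the sigmoid with the loss lets it apply the log-sum-exp trick.

```python
import math


def naive_neg_log_sigmoid(x: float) -> float:
    # Probability first, then log: rounds to 1.0 for large x,
    # and math.exp(-x) overflows for very negative x.
    p = 1.0 / (1.0 + math.exp(-x))
    return -math.log(p)


def stable_neg_log_sigmoid(x: float) -> float:
    # -log(sigmoid(x)) == softplus(-x) == max(-x, 0) + log1p(exp(-|x|))
    return max(-x, 0.0) + math.log1p(math.exp(-abs(x)))
```

With `x = 40.0`, the naive version rounds the probability to exactly 1.0 and returns zero, while the stable version returns roughly 4.2e-18. With `x = -800.0`, the naive version raises `OverflowError` from `math.exp(800.0)`, while the stable version returns 800.0.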