nichecompass.modules.compute_edge_recon_loss
- nichecompass.modules.compute_edge_recon_loss(edge_recon_logits, edge_recon_labels, edge_incl=None)
Compute the weighted binary cross-entropy edge reconstruction loss from predicted edge logits and ground-truth edge labels.
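For reference, a weighted binary cross-entropy with logits over the N included edges can be written as follows. Here x_i is an edge logit, y_i ∈ {0, 1} its label, σ the logistic sigmoid, and w a positive-class weight. The exact form of the weighting is an assumption: it mirrors PyTorch's pos_weight convention, and this page does not state how the weights are chosen.

```latex
\mathcal{L}_{\text{edge}} = -\frac{1}{N} \sum_{i=1}^{N} \Big[ w \, y_i \log \sigma(x_i) + (1 - y_i) \log\big(1 - \sigma(x_i)\big) \Big]
```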
- Parameters:
edge_recon_logits (Tensor) – Predicted edge reconstruction logits for both positive and negative sampled edges (dim: 2 * edge_batch_size).
edge_recon_labels (Tensor) – Ground-truth edge labels for both positive and negative sampled edges (dim: 2 * edge_batch_size).
edge_incl (Optional[Tensor], default: None) – Boolean mask indicating which edges to include in the edge reconstruction loss (dim: 2 * edge_batch_size). If None, all edges are included.
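The behaviour described by these parameters can be sketched in PyTorch as below. This is a hypothetical re-implementation for illustration, not the nichecompass source: the name edge_recon_loss_sketch is made up, and the weighting scheme (up-weighting positive edges by the ratio of negative to positive labels through pos_weight) is an assumption, since this page only says the loss is "weighted".

```python
from typing import Optional

import torch
import torch.nn.functional as F


def edge_recon_loss_sketch(
    edge_recon_logits: torch.Tensor,
    edge_recon_labels: torch.Tensor,
    edge_incl: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Hypothetical sketch of a weighted BCE edge reconstruction loss."""
    # Keep only the edges selected by the boolean inclusion mask.
    if edge_incl is not None:
        edge_recon_logits = edge_recon_logits[edge_incl]
        edge_recon_labels = edge_recon_labels[edge_incl]

    # ASSUMPTION: weight positive edges by (#negatives / #positives)
    # so that both classes contribute comparably to the loss.
    n_pos = (edge_recon_labels == 1.0).sum()
    n_neg = (edge_recon_labels == 0.0).sum()
    pos_weight = n_neg / n_pos  # true division -> float tensor

    # BCE computed directly on logits for numerical stability.
    return F.binary_cross_entropy_with_logits(
        edge_recon_logits, edge_recon_labels, pos_weight=pos_weight
    )


# Usage: edge_batch_size = 3, i.e. 3 positive + 3 negative sampled edges.
logits = torch.tensor([2.0, 1.5, -0.5, -1.0, 0.3, -2.0])
labels = torch.tensor([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
incl = torch.tensor([True, True, True, True, True, False])
loss = edge_recon_loss_sketch(logits, labels, edge_incl=incl)
print(loss)
```

With the mask, the last negative edge is dropped, so under this assumed scheme the positive weight becomes 2/3 (two negatives over three positives). Without a mask the classes are balanced and the weight is 1, which reduces to plain binary cross-entropy with logits.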
- Return type:
Tensor
- Returns:
edge_recon_loss: Weighted binary cross-entropy loss between the edge labels and the predicted edge probabilities (computed from logits for numerical stability during backpropagation).
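To illustrate why computing the loss from logits matters, the snippet below compares applying a sigmoid followed by binary cross-entropy against PyTorch's fused binary_cross_entropy_with_logits for a confidently wrong prediction. This is a generic PyTorch illustration, not nichecompass code; the logit value -200 is chosen only to force float32 underflow.

```python
import torch
import torch.nn.functional as F

# A confidently wrong prediction: large negative logit for a positive edge.
logit = torch.tensor([-200.0])
label = torch.tensor([1.0])

# Naive route: sigmoid(-200) underflows to 0 in float32, so log(0) = -inf.
# PyTorch's binary_cross_entropy clamps log outputs at -100, capping the loss.
naive = F.binary_cross_entropy(torch.sigmoid(logit), label)

# Logit route: evaluated in a numerically stable (log-sum-exp) form,
# which recovers the true loss of about 200 and keeps gradients informative.
stable = F.binary_cross_entropy_with_logits(logit, label)

print(naive.item(), stable.item())
```

The naive loss saturates at the clamp value, while the logit-based loss keeps growing with the size of the error, which is what gives useful gradients during backpropagation.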