nichecompass.modules.compute_gp_group_lasso_reg_loss

nichecompass.modules.compute_gp_group_lasso_reg_loss(model)

Compute the group lasso regularization loss for the weights of the masked decoder layer to enforce gene program sparsity. Each gene program is treated as one group, and each group is normalized by its number of non-masked weights.

Check https://leimao.github.io/blog/Group-Lasso/ for more information about group lasso regularization.
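The penalty described above can be illustrated with a minimal, framework-agnostic sketch in plain Python. It does not use the library's own code, and it rests on several assumptions. First, the decoder weights are laid out as genes by gene programs, so each column is one group. Second, a boolean mask of the same shape marks the non-masked weights. Third, the per-group normalization is the standard group lasso weighting by the square root of the group size, as in the referenced post. The exact scaling used by NicheCompass may differ. The function name `group_lasso_penalty` is hypothetical.

```python
import math
from typing import Sequence


def group_lasso_penalty(weights: Sequence[Sequence[float]],
                        mask: Sequence[Sequence[bool]]) -> float:
    """Hypothetical sketch of a group lasso penalty over a masked weight matrix.

    Assumptions (not taken from the NicheCompass source):
    - weights[i][j] is the weight from gene program j to gene i, so each
      column j is one group (one gene program).
    - mask[i][j] is True if that weight is non-masked (allowed to be nonzero).
    - Each group's L2 norm is scaled by sqrt(number of non-masked weights),
      the standard group lasso weighting.
    """
    n_rows = len(weights)
    n_groups = len(weights[0]) if n_rows else 0
    total = 0.0
    for j in range(n_groups):
        # Collect only the non-masked weights belonging to gene program j.
        group = [weights[i][j] for i in range(n_rows) if mask[i][j]]
        if not group:
            # A fully masked gene program contributes no penalty.
            continue
        l2_norm = math.sqrt(sum(w * w for w in group))
        # Scale by sqrt(group size) so larger gene programs are not favored.
        total += math.sqrt(len(group)) * l2_norm
    return total


# Usage: two gene programs over two genes; the second program masks gene 1.
W = [[3.0, 1.0],
     [4.0, 0.0]]
M = [[True, True],
     [True, False]]
penalty = group_lasso_penalty(W, M)
print(round(penalty, 4))  # 8.0711  (= sqrt(2) * 5 + sqrt(1) * 1)
```

Counting only non-masked entries matters for the normalization: masked weights contribute nothing to the decoder output and do not change a group's L2 norm, but including them would inflate the group size used for scaling.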

Parameters:

model (Module) – The VGPGAE module.

Return type:

Tensor

Returns:

group_lasso_reg_loss: Group lasso regularization loss for the decoder layer weights.