nichecompass.nn.MaskedLinear
- class nichecompass.nn.MaskedLinear(n_input, n_output, mask, bias=False)
Masked linear class.
Parts of the implementation are adapted from https://github.com/theislab/scarches/blob/master/scarches/models/expimap/modules.py#L9; 01.10.2022.
Uses static and dynamic binary masks to mask connections from the input layer to the output layer, so that only unmasked connections contribute to the output.
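The masking idea can be illustrated with a minimal pure-Python sketch. Plain lists stand in for PyTorch tensors so the arithmetic is visible. The helper `masked_linear` is hypothetical and is not the NicheCompass implementation. It makes two assumptions. First, the weight and both masks are indexed as `[input][output]`, that is, shape `(n_input, n_output)`, which is the mask convention of the scArches module referenced above. Second, the static and dynamic masks are combined by element-wise multiplication.

```python
from typing import List, Optional

Matrix = List[List[float]]


def masked_linear(x: List[float], weight: Matrix, mask: Matrix,
                  dynamic_mask: Optional[Matrix] = None,
                  bias: Optional[List[float]] = None) -> List[float]:
    """Compute y_j = sum_i x_i * W[i][j] * M[i][j] * D[i][j] (+ b_j).

    Hypothetical helper. Assumes weight, mask and dynamic_mask all have
    shape (n_input, n_output) and that the masks are multiplied together.
    """
    n_input, n_output = len(weight), len(weight[0])
    if len(x) != n_input:
        raise ValueError("input length must equal n_input")
    out = []
    for j in range(n_output):
        total = 0.0
        for i in range(n_input):
            m = mask[i][j]  # static mask entry (0 = connection removed)
            if dynamic_mask is not None:
                m *= dynamic_mask[i][j]  # dynamic mask can remove more connections
            total += x[i] * weight[i][j] * m
        if bias is not None:
            total += bias[j]
        out.append(total)
    return out


x = [1.0, 2.0, 3.0]
w = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
static = [[1, 0], [1, 1], [0, 1]]
print(masked_linear(x, w, static))  # [3.0, 5.0]
print(masked_linear(x, w, static, dynamic_mask=[[1, 1], [0, 1], [1, 1]]))  # [1.0, 5.0]
```

With all weights equal to 1, each output is the sum of the inputs whose connection survives both masks. Zeroing a dynamic-mask entry removes a connection that the static mask had allowed.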
Methods table
| Method | Description |
| forward(input, dynamic_mask=None) | Forward pass of the masked linear class. |
Methods
- MaskedLinear.forward(input, dynamic_mask=None)
Forward pass of the masked linear class.
- Parameters:
  - input (Tensor) – Tensor containing the input features to the masked linear class.
  - dynamic_mask (Optional[Tensor] (default: None)) – Additional optional Tensor containing a mask that changes during training.
- Return type:
  Tensor
- Returns:
  output – Tensor containing the output of the masked linear class, i.e. the linear transformation of the input computed using only unmasked connections.
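The documented call pattern can be mirrored in a hypothetical pure-Python class, `ToyMaskedLinear`. It follows the constructor signature `(n_input, n_output, mask, bias=False)` and the method `forward(input, dynamic_mask=None)`. This is a sketch under stated assumptions, not the library's code:

- the static mask has shape `(n_input, n_output)` and is fixed at construction;
- masked weights are zeroed at initialization;
- `input` is a batch given as a list of rows;
- the dynamic mask is passed per call with the same shape as the static mask.

```python
import random
from typing import List, Optional

Matrix = List[List[float]]


class ToyMaskedLinear:
    """Hypothetical stand-in mirroring MaskedLinear(n_input, n_output, mask, bias=False)."""

    def __init__(self, n_input: int, n_output: int, mask: Matrix,
                 bias: bool = False, seed: int = 0):
        # Assumed mask layout: one row per input feature, one column per output
        if len(mask) != n_input or any(len(row) != n_output for row in mask):
            raise ValueError("mask must have shape (n_input, n_output)")
        rng = random.Random(seed)
        self.n_input, self.n_output = n_input, n_output
        self.mask = mask
        # Random weights; entries blocked by the static mask start at zero
        self.weight = [[rng.uniform(-1.0, 1.0) * mask[i][j] for j in range(n_output)]
                       for i in range(n_input)]
        self.bias = [0.0] * n_output if bias else None

    def forward(self, input: Matrix, dynamic_mask: Optional[Matrix] = None) -> Matrix:
        # `input` shadows the builtin only to mirror the documented signature
        outputs = []
        for x in input:  # one row per observation in the batch
            row = []
            for j in range(self.n_output):
                acc = 0.0
                for i in range(self.n_input):
                    dyn = dynamic_mask[i][j] if dynamic_mask is not None else 1.0
                    acc += x[i] * self.weight[i][j] * self.mask[i][j] * dyn
                if self.bias is not None:
                    acc += self.bias[j]
                row.append(acc)
            outputs.append(row)
        return outputs


layer = ToyMaskedLinear(3, 2, mask=[[1, 0], [1, 1], [0, 1]])
out = layer.forward([[1.0, 2.0, 3.0]])
# Input 0 is not connected to output 1, so changing it leaves out[0][1] unchanged
```

Because the dynamic mask is an argument to `forward` rather than stored state, the sketch can receive a different mask at every training step without rebuilding the layer.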