nichecompass.nn.MaskedLinear

class nichecompass.nn.MaskedLinear(n_input, n_output, mask, bias=False)

Masked linear class.

Parts of the implementation are adapted from https://github.com/theislab/scarches/blob/master/scarches/models/expimap/modules.py#L9; 01.10.2022.

Uses a static binary mask and an optional dynamic binary mask to mask connections from the input layer to the output layer, so that only unmasked connections are used.
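The masking idea can be illustrated with a minimal PyTorch sketch. It assumes that the static and dynamic masks are combined with the weight matrix by element-wise multiplication before the linear map is applied. The helper `masked_linear` is hypothetical and not part of nichecompass, and all masks here use the `(n_output, n_input)` layout of `torch.nn.Linear.weight`.

```python
import torch
import torch.nn.functional as F


def masked_linear(x, weight, static_mask, dynamic_mask=None, bias=None):
    """Hypothetical helper: a linear map that ignores masked-out connections.

    weight, static_mask and dynamic_mask all have shape (n_output, n_input),
    the layout used by torch.nn.Linear; a 0 entry removes a connection.
    """
    effective_weight = weight * static_mask
    if dynamic_mask is not None:
        # The dynamic mask is applied on top of the static mask.
        effective_weight = effective_weight * dynamic_mask
    return F.linear(x, effective_weight, bias)


torch.manual_seed(0)
weight = torch.randn(2, 3)                  # 3 input nodes -> 2 output nodes
static_mask = torch.tensor([[1., 1., 0.],   # output 0 sees inputs 0 and 1
                            [0., 1., 1.]])  # output 1 sees inputs 1 and 2
dynamic_mask = torch.tensor([[1., 0., 1.],  # additionally drop input 1 -> output 0
                             [1., 1., 1.]])

x = torch.tensor([[1., 2., 3.]])
out = masked_linear(x, weight, static_mask, dynamic_mask)

# Output 0 now depends only on input 0; output 1 only on inputs 1 and 2.
expected = torch.stack([weight[0, 0] * 1.0,
                        weight[1, 1] * 2.0 + weight[1, 2] * 3.0]).unsqueeze(0)
print(torch.allclose(out, expected))  # True
```

A connection survives only if it is unmasked in both masks, so changing an input value along a masked connection leaves the corresponding output unchanged.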

Parameters:
  • n_input (int) – Number of input nodes to the masked layer.

  • n_output (int) – Number of output nodes from the masked layer.

  • mask (Tensor) – Static mask that is used to mask the node connections from the input layer to the output layer.

  • bias (default: False) – If True, use a bias.
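A minimal module with the constructor parameters listed above might look as follows. This is a hypothetical re-implementation (`SketchMaskedLinear`), not the nichecompass source. It assumes the static mask has shape `(n_input, n_output)` and transposes it internally to match `torch.nn.Linear.weight`. Because the mask is applied in the forward pass, masked weights receive zero gradient from the loss and have no effect on the output. Registering the mask as a buffer keeps it out of the trainable parameters while still saving it in the `state_dict` and moving it with `.to(device)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchMaskedLinear(nn.Linear):
    """Hypothetical masked linear layer with a static mask only."""

    def __init__(self, n_input, n_output, mask, bias=False):
        # Assumption: the mask has shape (n_input, n_output).
        if mask.shape != (n_input, n_output):
            raise ValueError(f"mask must have shape ({n_input}, {n_output}), "
                             f"got {tuple(mask.shape)}")
        super().__init__(n_input, n_output, bias=bias)
        # Store the mask in nn.Linear's (n_output, n_input) layout; as a
        # buffer it is saved with the module but never trained.
        self.register_buffer("mask", mask.t().float())

    def forward(self, input):
        # Masked-out weights contribute nothing and get zero gradient.
        return F.linear(input, self.weight * self.mask, self.bias)


mask = torch.tensor([[1., 0.],
                     [0., 1.],
                     [1., 1.]])          # 3 input nodes -> 2 output nodes
layer = SketchMaskedLinear(3, 2, mask)   # bias defaults to False
layer(torch.randn(4, 3)).sum().backward()

print(layer.bias)                                           # None
print(bool((layer.weight.grad[layer.mask == 0] == 0).all()))  # True
```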

Methods table

forward(input[, dynamic_mask])

Forward pass of the masked linear class.

Methods

MaskedLinear.forward(input, dynamic_mask=None)

Forward pass of the masked linear class.

Parameters:
  • input (Tensor) – Tensor containing the input features to the masked linear class.

  • dynamic_mask (Optional[Tensor] (default: None)) – Optional additional mask, applied on top of the static mask, that can change during training.

Return type:

Tensor

Returns:

output: Tensor containing the output of the masked linear class, i.e. the linear transformation of the input that considers only unmasked connections.
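The role of `dynamic_mask` can be sketched with a toy training loop in which the dynamic mask changes between steps while the static mask stays fixed. This is a hypothetical, self-contained re-implementation (`SketchMaskedLinear`, with both masks assumed to have shape `(n_input, n_output)`). The rule used to switch off a connection after the first step is purely illustrative and is not how nichecompass builds its dynamic masks.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchMaskedLinear(nn.Linear):
    """Hypothetical masked linear layer with static and dynamic masks."""

    def __init__(self, n_input, n_output, mask, bias=False):
        super().__init__(n_input, n_output, bias=bias)
        # Assumption: mask has shape (n_input, n_output); store it transposed.
        self.register_buffer("mask", mask.t().float())

    def forward(self, input, dynamic_mask=None):
        weight = self.weight * self.mask
        if dynamic_mask is not None:
            # Assumption: the dynamic mask uses the same (n_input, n_output) layout.
            weight = weight * dynamic_mask.t()
        return F.linear(input, weight, self.bias)


torch.manual_seed(0)
static_mask = torch.ones(3, 2)  # all connections allowed by the static mask
layer = SketchMaskedLinear(3, 2, static_mask)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.1)  # no momentum, no weight decay
x, y = torch.randn(8, 3), torch.randn(8, 2)

dynamic_mask = torch.ones(3, 2)
for step in range(2):
    optimizer.zero_grad()
    loss = F.mse_loss(layer(x, dynamic_mask=dynamic_mask), y)
    loss.backward()
    optimizer.step()
    if step == 0:
        # Illustrative rule: switch off the input 2 -> output 0 connection.
        weight_before = layer.weight[0, 2].item()
        dynamic_mask = dynamic_mask.clone()
        dynamic_mask[2, 0] = 0.0

# The switched-off weight received zero gradient, so SGD left it unchanged.
print(layer.weight[0, 2].item() == weight_before)  # True
```

In this sketch the masks are applied in the forward pass without overwriting the stored weights, so re-enabling a connection in a later dynamic mask brings its last stored value back into use.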