loki2.mil.models.src.model
Gated Attention Multiple Instance Learning (MIL) model.
This module implements a Gated Attention MIL model that predicts binary patient-level labels from bags of tile embeddings (instances).
Module Contents
- class loki2.mil.models.src.model.GatedAttentionMIL(embedding_dim: int = 1280, M: int = 1280, L: int = 128, attention_branches: int = 1, dropout_rate: float = 0.6, norm_type: Literal['bn', 'ln', None] = 'bn', output_type: Literal['logits', 'probs'] = 'logits')
Bases:
torch.nn.Module

Gated Attention MIL model for patient-level classification.
This model uses a gated attention mechanism to aggregate tile-level embeddings into patient-level predictions. It supports multiple attention branches and can output either logits or probabilities.
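The gated attention mechanism scores each instance embedding h with two parallel projections, a tanh branch gated elementwise by a sigmoid branch, then softmax-normalizes the scores over the bag. A minimal NumPy sketch of a single-branch forward pass for one bag (the weight names V, U, w mirror the attention_V, attention_U, and attention_w modules above; all dimensions and random weights are illustrative, not the model's actual parameters):

```python
import numpy as np

rng = np.random.default_rng(0)

M, L_dim, bag_size = 1280, 128, 16           # feature dim M and attention dim L
H = rng.standard_normal((bag_size, M))       # instance (tile) embeddings, one bag

V = rng.standard_normal((M, L_dim))          # attention_V weights (tanh branch)
U = rng.standard_normal((M, L_dim))          # attention_U weights (sigmoid gate)
w = rng.standard_normal((L_dim, 1))          # attention_w projection

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Gated attention score per instance: w^T (tanh(V^T h) * sigmoid(U^T h))
scores = (np.tanh(H @ V) * sigmoid(H @ U)) @ w   # (bag_size, 1)
A = softmax(scores.T, axis=-1)                   # (1, bag_size), sums to 1
Z = A @ H                                        # (1, M) attention-pooled bag representation
```

Z is the patient-level representation that the classifier head then maps to a single logit; with multiple attention branches, A has one row per branch and the pooled representations are concatenated.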
- M
Feature extractor output dimension.
- L
Attention dimension.
- ATTENTION_BRANCHES
Number of attention branches.
- norm_type
Normalization type (“bn”, “ln”, or None).
- output_type
Output type (“logits” or “probs”).
- feature_extractor
Sequential feature extraction layers.
- attention_V
Attention value network (tanh activation).
- attention_U
Attention gate network (sigmoid activation).
- attention_w
Attention weight projection layer.
- classifier
Final classification layer.
- layers
- forward(x: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Forward pass through the model.
- Parameters:
x – Input tensor of shape (batch_size, bag_size, embedding_dim) containing bag-of-instances tile embeddings.
- Returns:
- A tuple containing:
  - Y_logits or Y_prob: Model output (logits or probabilities) of shape (batch_size, 1).
  - Y_hat: Binary predictions of shape (batch_size, 1).
  - A: Attention weights of shape (batch_size, attention_branches, bag_size).
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
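When output_type is "logits", probabilities can be recovered downstream with a sigmoid, and the binary prediction Y_hat corresponds to thresholding those probabilities. A small NumPy sketch of that relationship (the logit values and the 0.5 threshold are illustrative assumptions, not values produced by the model):

```python
import numpy as np

# Hypothetical Y_logits for a batch of 4 patients, shape (batch_size, 1)
logits = np.array([[-1.2], [0.3], [2.0], [-0.1]])

probs = 1.0 / (1.0 + np.exp(-logits))   # sigmoid: logits -> probabilities
y_hat = (probs >= 0.5).astype(int)      # threshold at 0.5 -> binary predictions
# y_hat -> [[0], [1], [1], [0]]
```

Training with "logits" output pairs naturally with torch.nn.BCEWithLogitsLoss, which applies the sigmoid internally for numerical stability; "probs" output pairs with torch.nn.BCELoss.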