loki2.mil.models.src.model

Gated Attention Multiple Instance Learning (MIL) model.

This module implements a Gated Attention MIL model for binary patient-level classification from bags of tile embeddings.

Module Contents

class loki2.mil.models.src.model.GatedAttentionMIL(embedding_dim: int = 1280, M: int = 1280, L: int = 128, attention_branches: int = 1, dropout_rate: float = 0.6, norm_type: Literal['bn', 'ln', None] = 'bn', output_type: Literal['logits', 'probs'] = 'logits')

Bases: torch.nn.Module

Gated Attention MIL model for patient-level classification.

This model uses a gated attention mechanism to aggregate tile-level embeddings into patient-level predictions. It supports multiple attention branches and can output either logits or probabilities.
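The gated attention pooling can be sketched as follows. This is a minimal illustration following the standard gated-attention formulation (Ilse et al.), not the module's exact internals; the dimensions mirror the class defaults (M=1280, L=128, one attention branch):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Gated attention pooling over a single bag of instance embeddings.
M, L, branches = 1280, 128, 1
attention_V = nn.Sequential(nn.Linear(M, L), nn.Tanh())     # feature branch
attention_U = nn.Sequential(nn.Linear(M, L), nn.Sigmoid())  # gating branch
attention_w = nn.Linear(L, branches)                        # score projection

H = torch.randn(50, M)                                 # bag of 50 tile embeddings
scores = attention_w(attention_V(H) * attention_U(H))  # (50, branches) raw scores
A = F.softmax(scores.transpose(0, 1), dim=1)           # (branches, 50), sums to 1
Z = A @ H                                              # (branches, M) bag embedding
```

The element-wise product of the tanh and sigmoid branches lets the sigmoid act as a learned gate on the tanh features before the scores are normalized over the bag.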

M

Feature extractor output dimension.

L

Attention dimension.

ATTENTION_BRANCHES

Number of attention branches.

norm_type

Normalization type (“bn”, “ln”, or None).

output_type

Output type (“logits” or “probs”).
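Since the task is binary, the two output types are related by a sigmoid. A small illustration (the 0.5 decision threshold here is an assumption, not taken from the source):

```python
import torch

logits = torch.tensor([[0.0], [2.0]])  # (batch_size, 1) raw model outputs
probs = torch.sigmoid(logits)          # what output_type="probs" would yield
preds = (probs > 0.5).float()          # binary predictions from probabilities
```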

feature_extractor

Sequential feature extraction layers.

attention_V

Attention value network (tanh activation).

attention_U

Attention gate network (sigmoid activation).

attention_w

Attention weight projection layer.

classifier

Final classification layer.

forward(x: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Forward pass through the model.

Parameters:

x – Input tensor of shape (batch_size, bag_size, embedding_dim) containing bags of tile embeddings.

Returns:

A tuple containing:
  • Y_logits or Y_prob: Model output (logits or probabilities) of shape (batch_size, 1).

  • Y_hat: Binary predictions of shape (batch_size, 1).

  • A: Attention weights of shape (batch_size, attention_branches, bag_size).

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
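The documented shapes can be exercised end to end with the following self-contained sketch. This is a hypothetical reimplementation wired to the class defaults, written only to demonstrate the return contract; it is not the library's actual code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedAttentionMILSketch(nn.Module):
    """Hypothetical sketch: (batch, bag, emb) -> (logits, preds, attention)."""

    def __init__(self, embedding_dim=1280, M=1280, L=128, branches=1):
        super().__init__()
        self.feature_extractor = nn.Linear(embedding_dim, M)
        self.attention_V = nn.Sequential(nn.Linear(M, L), nn.Tanh())
        self.attention_U = nn.Sequential(nn.Linear(M, L), nn.Sigmoid())
        self.attention_w = nn.Linear(L, branches)
        self.classifier = nn.Linear(M * branches, 1)

    def forward(self, x):
        H = self.feature_extractor(x)                    # (B, bag, M)
        A = self.attention_w(self.attention_V(H) * self.attention_U(H))
        A = F.softmax(A.transpose(1, 2), dim=2)          # (B, branches, bag)
        Z = torch.bmm(A, H).flatten(1)                   # (B, branches * M)
        Y_logits = self.classifier(Z)                    # (B, 1)
        Y_hat = (torch.sigmoid(Y_logits) > 0.5).float()  # (B, 1) binary preds
        return Y_logits, Y_hat, A


# Batch of 4 patients, each a bag of 50 tile embeddings of dimension 1280.
x = torch.randn(4, 50, 1280)
Y_logits, Y_hat, A = GatedAttentionMILSketch()(x)
```

Running this yields `Y_logits` of shape (4, 1), `Y_hat` of shape (4, 1), and attention weights `A` of shape (4, 1, 50), matching the tuple documented above.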