loki2.mil.models.src.model
==========================

.. py:module:: loki2.mil.models.src.model

.. autoapi-nested-parse::

   Gated Attention Multiple Instance Learning (MIL) model.

   This module implements a Gated Attention MIL model for binary classification
   of patient-level labels from bag-of-instances tile embeddings.

Module Contents
---------------

.. py:class:: GatedAttentionMIL(embedding_dim: int = 1280, M: int = 1280, L: int = 128, attention_branches: int = 1, dropout_rate: float = 0.6, norm_type: Literal['bn', 'ln', None] = 'bn', output_type: Literal['logits', 'probs'] = 'logits')

   Bases: :py:obj:`torch.nn.Module`

   Gated Attention MIL model for patient-level classification.

   This model uses a gated attention mechanism to aggregate tile-level embeddings
   into patient-level predictions. It supports multiple attention branches and can
   output either logits or probabilities.

   .. attribute:: M

      Feature extractor output dimension.

   .. attribute:: L

      Attention dimension.

   .. attribute:: ATTENTION_BRANCHES

      Number of attention branches.

   .. attribute:: norm_type

      Normalization type ("bn", "ln", or None).

   .. attribute:: output_type

      Output type ("logits" or "probs").

   .. attribute:: feature_extractor

      Sequential feature extraction layers.

   .. attribute:: attention_V

      Attention value network (tanh activation).

   .. attribute:: attention_U

      Attention gate network (sigmoid activation).

   .. attribute:: attention_w

      Attention weight projection layer.

   .. attribute:: classifier

      Final classification layer.

   .. py:attribute:: M

   .. py:attribute:: L

   .. py:attribute:: ATTENTION_BRANCHES

   .. py:attribute:: norm_type

   .. py:attribute:: output_type

   .. py:attribute:: layers

   .. py:attribute:: feature_extractor

   .. py:attribute:: attention_V

   .. py:attribute:: attention_U

   .. py:attribute:: attention_w

   .. py:method:: forward(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

      Forward pass through the model.

      :param x: Input tensor of shape (batch_size, bag_size, embedding_dim)
                containing bag-of-instances tile embeddings.
      :returns: A tuple containing:

                - Y_logits or Y_prob: Model output (logits or probabilities) of shape (batch_size, 1).
                - Y_hat: Binary predictions of shape (batch_size, 1).
                - A: Attention weights of shape (batch_size, attention_branches, bag_size).
      :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
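
A minimal usage sketch, assuming the constructor signature and tensor shapes documented above. The import path follows the module name; the bag size (500 tiles) and the return-value names are illustrative only.

.. code-block:: python

   import torch

   # Import path assumed from the documented module name.
   from loki2.mil.models.src.model import GatedAttentionMIL

   # Instantiate with the documented defaults: 1280-dim tile embeddings,
   # batch-norm feature extractor, single attention branch, logit outputs.
   model = GatedAttentionMIL(embedding_dim=1280, norm_type="bn", output_type="logits")
   model.eval()  # disable dropout for inference

   # One bag of 500 tile embeddings for a single patient:
   # shape (batch_size, bag_size, embedding_dim).
   x = torch.randn(1, 500, 1280)

   with torch.no_grad():
       y_logits, y_hat, attention = model(x)

   print(y_logits.shape)   # (1, 1) patient-level logit
   print(y_hat.shape)      # (1, 1) binary prediction
   print(attention.shape)  # (1, attention_branches, bag_size) = (1, 1, 500)

The attention weights returned as the third element can be mapped back to the input tiles to inspect which instances drove the patient-level prediction.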