loki2.models.utils.sam_utils
============================

.. py:module:: loki2.models.utils.sam_utils

.. autoapi-nested-parse::

   SAM (Segment Anything Model) utilities for Loki2.

   This module provides utilities adapted from Meta's Segment Anything Model (SAM)
   for use as backbones in cell segmentation networks.

   Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.

   This source code is licensed under the license found in the LICENSE file in the
   root directory of this source tree.

   This file has not been changed by the author of this repository (Fabian Hörst,
   fabian.hoerst@uk-essen.de, Institute for Artificial Intelligence in Medicine,
   University Medicine Essen).


Module Contents
---------------

.. py:class:: MLPBlock(embedding_dim: int, mlp_dim: int, act: Type[torch.nn.Module] = nn.GELU)

   Bases: :py:obj:`torch.nn.Module`

   Multi-layer perceptron block with two linear layers and an activation, applied as
   ``lin2(act(lin1(x)))``.

   :param embedding_dim: Input and output embedding dimension.
   :param mlp_dim: Hidden dimension of the MLP.
   :param act: Activation function class. Defaults to nn.GELU.

   .. py:attribute:: lin1

   .. py:attribute:: lin2

   .. py:attribute:: act

   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor


.. py:class:: LayerNorm2d(num_channels: int, eps: float = 1e-06)

   Bases: :py:obj:`torch.nn.Module`

   Layer normalization for (B, C, H, W) image tensors. Each spatial position is
   normalized over the channel dimension, then scaled and shifted by the learnable
   per-channel ``weight`` and ``bias``.

   :param num_channels: Number of channels in the input tensor.
   :param eps: Small epsilon value for numerical stability. Defaults to 1e-6.

   .. py:attribute:: weight

   .. py:attribute:: bias

   .. py:attribute:: eps

   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor


.. py:class:: ImageEncoderViT(img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[torch.nn.Module] = nn.LayerNorm, act_layer: Type[torch.nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, Ellipsis] = ())

   Bases: :py:obj:`torch.nn.Module`

   Vision Transformer image encoder adapted from SAM/ViTDet.

   This class and its supporting functions are lightly adapted from the ViTDet
   backbone available at:
   https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py

   :param img_size: Input image size. Defaults to 1024.
   :param patch_size: Patch size. Defaults to 16.
   :param in_chans: Number of input image channels. Defaults to 3.
   :param embed_dim: Patch embedding dimension. Defaults to 768.
   :param depth: Number of transformer blocks (depth of the ViT). Defaults to 12.
   :param num_heads: Number of attention heads in each ViT block. Defaults to 12.
   :param mlp_ratio: Ratio of the MLP hidden dimension to the embedding dimension. Defaults to 4.0.
   :param out_chans: Number of output channels of the neck. Defaults to 256.
   :param qkv_bias: If True, add a learnable bias to query, key, value. Defaults to True.
   :param norm_layer: Normalization layer. Defaults to nn.LayerNorm.
   :param act_layer: Activation layer. Defaults to nn.GELU.
   :param use_abs_pos: If True, use absolute positional embeddings. Defaults to True.
   :param use_rel_pos: If True, add relative positional embeddings to the attention map. Defaults to False.
   :param rel_pos_zero_init: If True, zero-initialize the relative positional parameters. Defaults to True.
   :param window_size: Window size for window attention blocks. Defaults to 0.
   :param global_attn_indexes: Indices of the blocks that use global attention instead of window attention. Defaults to ().

   .. py:attribute:: img_size

   .. py:attribute:: patch_embed

   .. py:attribute:: pos_embed
      :type: Optional[torch.nn.Parameter]
      :value: None

   .. py:attribute:: blocks

   .. py:attribute:: neck

   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor

      Forward pass through the image encoder.

      :param x: Input image tensor of shape (B, C, H, W).

      :returns: Encoded features of shape (B, out_chans, H', W'), where H' = H // patch_size and W' = W // patch_size.
      :rtype: torch.Tensor


.. py:class:: Block(dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[torch.nn.Module] = nn.LayerNorm, act_layer: Type[torch.nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Optional[Tuple[int, int]] = None)

   Bases: :py:obj:`torch.nn.Module`

   Transformer block with support for window attention and residual connections.

   :param dim: Number of input channels.
   :param num_heads: Number of attention heads in each ViT block.
   :param mlp_ratio: Ratio of the MLP hidden dimension to the embedding dimension. Defaults to 4.0.
   :param qkv_bias: If True, add a learnable bias to query, key, value. Defaults to True.
   :param norm_layer: Normalization layer. Defaults to nn.LayerNorm.
   :param act_layer: Activation layer. Defaults to nn.GELU.
   :param use_rel_pos: If True, add relative positional embeddings to the attention map. Defaults to False.
   :param rel_pos_zero_init: If True, zero-initialize the relative positional parameters. Defaults to True.
   :param window_size: Window size for window attention blocks. If 0, global attention is used. Defaults to 0.
   :param input_size: Input resolution used to calculate the size of the relative positional parameters. Defaults to None.

   .. py:attribute:: norm1

   .. py:attribute:: attn

   .. py:attribute:: norm2

   .. py:attribute:: mlp

   .. py:attribute:: window_size

   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor

      Forward pass through the transformer block.

      :param x: Input tensor of shape (B, H, W, C).

      :returns: Output tensor of the same shape after attention and MLP processing.
      :rtype: torch.Tensor


.. py:class:: Attention(dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Optional[Tuple[int, int]] = None)

   Bases: :py:obj:`torch.nn.Module`

   Multi-head attention block with optional relative position embeddings.

   :param dim: Number of input channels.
   :param num_heads: Number of attention heads. Defaults to 8.
   :param qkv_bias: If True, add a learnable bias to query, key, value. Defaults to True.
   :param use_rel_pos: If True, add relative positional embeddings to the attention map. Defaults to False.
   :param rel_pos_zero_init: If True, zero-initialize the relative positional parameters. Defaults to True.
   :param input_size: Input resolution used to calculate the size of the relative positional parameters. Defaults to None.

   .. py:attribute:: num_heads

   .. py:attribute:: head_dim

   .. py:attribute:: scale

   .. py:attribute:: qkv

   .. py:attribute:: proj

   .. py:attribute:: use_rel_pos

   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor

      Forward pass through the attention block.

      :param x: Input tensor of shape (B, H, W, C).

      :returns: Output tensor of shape (B, H, W, C).
      :rtype: torch.Tensor


.. py:function:: window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]

   Partition input tokens into non-overlapping windows, zero-padding the height and
   width up to multiples of ``window_size`` if needed.

   :param x: Input tokens with shape [B, H, W, C].
   :param window_size: Window size for partitioning.

   :returns:
       - windows: Windows after partition with shape [B * num_windows, window_size, window_size, C].
       - (Hp, Wp): Padded height and width before partition.

   :rtype: Tuple[torch.Tensor, Tuple[int, int]]


.. py:function:: window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]) -> torch.Tensor

   Merge windows back into the original [B, H, W, C] token layout and remove any padding.

   :param windows: Input tokens with shape [B * num_windows, window_size, window_size, C].
   :param window_size: Window size.
   :param pad_hw: Padded height and width (Hp, Wp).
   :param hw: Original height and width (H, W) before padding.

   :returns: Unpartitioned tokens with shape [B, H, W, C].
   :rtype: torch.Tensor


.. py:function:: get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor

   Get relative positional embeddings for every (query, key) position pair along one
   axis. If ``rel_pos`` does not have ``2 * max(q_size, k_size) - 1`` rows, it is
   first resized to that length by linear interpolation.

   :param q_size: Size of query q.
   :param k_size: Size of key k.
   :param rel_pos: Relative position embeddings of shape (L, C).

   :returns: Extracted positional embeddings of shape (q_size, k_size, C).
   :rtype: torch.Tensor


.. py:function:: add_decomposed_rel_pos(attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int]) -> torch.Tensor

   Calculate decomposed relative positional embeddings from MViTv2 and add them to
   the attention map.

   Reference:
   https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

   :param attn: Attention map of shape (B, q_h * q_w, k_h * k_w).
   :param q: Query q in the attention layer with shape (B, q_h * q_w, C).
   :param rel_pos_h: Relative position embeddings of shape (Lh, C) for the height axis.
   :param rel_pos_w: Relative position embeddings of shape (Lw, C) for the width axis.
   :param q_size: Spatial sequence size of query q with (q_h, q_w).
   :param k_size: Spatial sequence size of key k with (k_h, k_w).

   :returns: Attention map with added relative positional embeddings.
   :rtype: torch.Tensor


.. py:class:: PatchEmbed(kernel_size: Tuple[int, int] = (16, 16), stride: Tuple[int, int] = (16, 16), padding: Tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768)

   Bases: :py:obj:`torch.nn.Module`

   Image to patch embedding module.

   Converts input images into patch embeddings using a convolutional layer.

   :param kernel_size: Kernel size of the projection layer. Defaults to (16, 16).
   :param stride: Stride of the projection layer. Defaults to (16, 16).
   :param padding: Padding size of the projection layer. Defaults to (0, 0).
   :param in_chans: Number of input image channels. Defaults to 3.
   :param embed_dim: Patch embedding dimension. Defaults to 768.

   .. py:attribute:: proj

   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor

      Forward pass through patch embedding.

      :param x: Input image tensor of shape (B, C, H, W).

      :returns: Patch embeddings in channels-last layout with shape (B, H', W', embed_dim), where H' and W' are the height and width of the patch grid.
      :rtype: torch.Tensor
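
The padding and reshaping described for ``window_partition`` and ``window_unpartition`` can be illustrated with a minimal sketch. It is a standalone re-implementation written for illustration and modeled on the SAM/ViTDet approach. It is not the ``loki2`` code itself, and the names ``window_partition_sketch`` and ``window_unpartition_sketch`` are hypothetical.

.. code-block:: python

   from typing import Tuple

   import torch
   import torch.nn.functional as F


   def window_partition_sketch(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
       """Hypothetical sketch: split (B, H, W, C) tokens into square windows."""
       B, H, W, C = x.shape
       # Zero-pad the bottom and right edges so H and W become multiples of window_size.
       pad_h = (window_size - H % window_size) % window_size
       pad_w = (window_size - W % window_size) % window_size
       if pad_h > 0 or pad_w > 0:
           # Pads are listed last dim first: C (none), W (right side), H (bottom side).
           x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
       Hp, Wp = H + pad_h, W + pad_w
       # (B, Hp, Wp, C) -> (B, Hp/ws, ws, Wp/ws, ws, C) -> (B * num_windows, ws, ws, C)
       x = x.reshape(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
       windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
       return windows, (Hp, Wp)


   def window_unpartition_sketch(
       windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
   ) -> torch.Tensor:
       """Hypothetical sketch: inverse of window_partition_sketch."""
       Hp, Wp = pad_hw
       H, W = hw
       num_windows = (Hp // window_size) * (Wp // window_size)
       B = windows.shape[0] // num_windows
       x = windows.reshape(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
       x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
       # Crop away the padding that window_partition_sketch added.
       return x[:, :H, :W, :].contiguous()


   # Usage: a 14 x 10 token grid with 4 x 4 windows is padded to 16 x 12.
   tokens = torch.randn(2, 14, 10, 8)
   windows, pad_hw = window_partition_sketch(tokens, window_size=4)
   restored = window_unpartition_sketch(windows, 4, pad_hw, (14, 10))
   print(windows.shape, pad_hw)  # torch.Size([24, 4, 4, 8]) (16, 12)

Each image yields (16 / 4) * (12 / 4) = 12 windows, so a batch of 2 gives 24. Because the documented ``window_partition`` and ``window_unpartition`` take the same arguments, they could presumably be swapped in for the sketch functions to check the same round trip when ``loki2`` is installed.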
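
``LayerNorm2d`` normalizes each spatial position of a (B, C, H, W) tensor over its C channels. The hypothetical ``LayerNorm2dSketch`` class is a minimal re-implementation for illustration, not the ``loki2`` class. It assumes that ``weight`` and ``bias`` start as ones and zeros, as in SAM. The check compares it with ``torch.nn.LayerNorm`` applied to a channels-last permutation, which should agree up to floating-point error.

.. code-block:: python

   import torch
   import torch.nn as nn


   class LayerNorm2dSketch(nn.Module):
       """Hypothetical sketch: layer norm over the channel dim of (B, C, H, W) tensors."""

       def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
           super().__init__()
           self.weight = nn.Parameter(torch.ones(num_channels))
           self.bias = nn.Parameter(torch.zeros(num_channels))
           self.eps = eps

       def forward(self, x: torch.Tensor) -> torch.Tensor:
           # Mean and biased variance over channels, computed separately at every (h, w).
           u = x.mean(1, keepdim=True)
           s = (x - u).pow(2).mean(1, keepdim=True)
           x = (x - u) / torch.sqrt(s + self.eps)
           # Per-channel affine transform, broadcast over H and W.
           return self.weight[:, None, None] * x + self.bias[:, None, None]


   x = torch.randn(2, 8, 5, 5)
   out = LayerNorm2dSketch(8)(x)
   # Reference: nn.LayerNorm over the last dim of a channels-last view, permuted back.
   ref = nn.LayerNorm(8, eps=1e-6)(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
   print(torch.allclose(out, ref, atol=1e-5))  # True

A dedicated 2D variant keeps convolutional feature maps in (B, C, H, W) layout and avoids the two permutes the reference path needs.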
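
``get_rel_pos`` turns a table of per-offset embeddings into a (q_size, k_size, C) lookup. ``add_decomposed_rel_pos`` then adds separate height and width terms to the attention logits. A minimal sketch of both steps follows, written for illustration after the SAM/ViTDet logic. The ``*_sketch`` names are hypothetical, and ``B`` is simply the leading dimension shared by ``q`` and ``attn``.

.. code-block:: python

   from typing import Tuple

   import torch
   import torch.nn.functional as F


   def get_rel_pos_sketch(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
       """Hypothetical sketch: gather (q_size, k_size, C) embeddings from an (L, C) table."""
       max_rel_dist = int(2 * max(q_size, k_size) - 1)
       if rel_pos.shape[0] != max_rel_dist:
           # Resize the table to max_rel_dist rows with 1D linear interpolation.
           resized = F.interpolate(
               rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
               size=max_rel_dist,
               mode="linear",
           )
           rel_pos = resized.reshape(-1, max_rel_dist).permute(1, 0)
       # Offsets q - k (rescaled if q and k differ in length), shifted so they start at 0.
       q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
       k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
       relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
       return rel_pos[relative_coords.long()]


   def add_decomposed_rel_pos_sketch(
       attn: torch.Tensor,
       q: torch.Tensor,
       rel_pos_h: torch.Tensor,
       rel_pos_w: torch.Tensor,
       q_size: Tuple[int, int],
       k_size: Tuple[int, int],
   ) -> torch.Tensor:
       """Hypothetical sketch: add height and width relative-position terms to attn."""
       q_h, q_w = q_size
       k_h, k_w = k_size
       Rh = get_rel_pos_sketch(q_h, k_h, rel_pos_h)  # (q_h, k_h, C)
       Rw = get_rel_pos_sketch(q_w, k_w, rel_pos_w)  # (q_w, k_w, C)
       B, _, dim = q.shape
       r_q = q.reshape(B, q_h, q_w, dim)
       rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)  # query dot height embedding
       rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)  # query dot width embedding
       attn = attn.reshape(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
       return attn.reshape(B, q_h * q_w, k_h * k_w)


   # Usage: with q_size == k_size == 4, lookup[i, j] is row i - j + 3 of the table.
   table = torch.arange(7, dtype=torch.float32).unsqueeze(1)  # (L=7, C=1)
   lookup = get_rel_pos_sketch(4, 4, table)
   print(lookup[..., 0])

Decomposing the bias into a height term and a width term means only (2 * size - 1) embeddings per axis are needed, instead of one embedding per 2D offset.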
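
``PatchEmbed`` is described as a single convolutional projection whose output is returned in channels-last layout. A minimal sketch of that idea shows how the patch-grid size follows from the kernel and stride. It assumes a plain ``nn.Conv2d`` followed by a permute, and the class name ``PatchEmbedSketch`` is hypothetical.

.. code-block:: python

   from typing import Tuple

   import torch
   import torch.nn as nn


   class PatchEmbedSketch(nn.Module):
       """Hypothetical sketch: image (B, C, H, W) -> patch tokens (B, H', W', embed_dim)."""

       def __init__(
           self,
           kernel_size: Tuple[int, int] = (16, 16),
           stride: Tuple[int, int] = (16, 16),
           padding: Tuple[int, int] = (0, 0),
           in_chans: int = 3,
           embed_dim: int = 768,
       ) -> None:
           super().__init__()
           # With kernel_size == stride, each output position covers one non-overlapping patch.
           self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

       def forward(self, x: torch.Tensor) -> torch.Tensor:
           x = self.proj(x)  # (B, embed_dim, H', W')
           return x.permute(0, 2, 3, 1)  # channels-last: (B, H', W', embed_dim)


   embed = PatchEmbedSketch()
   tokens = embed(torch.randn(1, 3, 256, 256))
   print(tokens.shape)  # torch.Size([1, 16, 16, 768])

For the ``ImageEncoderViT`` default ``img_size=1024`` with 16 x 16 patches, the same arithmetic gives a 64 x 64 token grid (1024 // 16 = 64).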