loki2.models.utils.sam_utils
SAM (Segment Anything Model) utilities for Loki2.
This module provides utilities adapted from Meta’s Segment Anything Model (SAM) for use as backbones in cell segmentation networks.
Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.
This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.
This file has not been changed by the author of this repository.
@ Fabian Hörst, fabian.hoerst@uk-essen.de, Institute for Artificial Intelligence in Medicine, University Medicine Essen
Module Contents
- class loki2.models.utils.sam_utils.MLPBlock(embedding_dim: int, mlp_dim: int, act: Type[torch.nn.Module] = nn.GELU)
Bases: torch.nn.Module
Multi-layer Perceptron block with two linear layers and activation.
- Parameters:
embedding_dim – Input and output embedding dimension.
mlp_dim – Hidden dimension of the MLP.
act – Activation function class. Defaults to nn.GELU.
- lin1
- lin2
- act
- forward(x: torch.Tensor) → torch.Tensor
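A minimal usage sketch, assuming the documented signature and that the block preserves the trailing embedding dimension (as in the original SAM implementation); sizes are illustrative:

```python
import torch

from loki2.models.utils.sam_utils import MLPBlock

# Hypothetical token batch: (batch, num_tokens, embedding_dim).
mlp = MLPBlock(embedding_dim=768, mlp_dim=3072)
tokens = torch.randn(2, 196, 768)
out = mlp(tokens)  # lin1 -> act -> lin2; expected shape (2, 196, 768)
```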
- class loki2.models.utils.sam_utils.LayerNorm2d(num_channels: int, eps: float = 1e-06)
Bases: torch.nn.Module
2D Layer Normalization for image tensors.
- Parameters:
num_channels – Number of channels in the input tensor.
eps – Small epsilon value for numerical stability. Defaults to 1e-6.
- weight
- bias
- eps
- forward(x: torch.Tensor) → torch.Tensor
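A short sketch, assuming the SAM convention that normalization runs over the channel dimension of a channels-first (B, C, H, W) tensor:

```python
import torch

from loki2.models.utils.sam_utils import LayerNorm2d

norm = LayerNorm2d(num_channels=256)
feats = torch.randn(2, 256, 64, 64)  # (B, C, H, W) feature map
out = norm(feats)                    # normalized, same shape as the input
```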
- class loki2.models.utils.sam_utils.ImageEncoderViT(img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[torch.nn.Module] = nn.LayerNorm, act_layer: Type[torch.nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = ())
Bases: torch.nn.Module
Vision Transformer image encoder adapted from SAM/ViTDet.
This class and its supporting functions are lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py
- Parameters:
img_size – Input image size. Defaults to 1024.
patch_size – Patch size. Defaults to 16.
in_chans – Number of input image channels. Defaults to 3.
embed_dim – Patch embedding dimension. Defaults to 768.
depth – Depth of ViT. Defaults to 12.
num_heads – Number of attention heads in each ViT block. Defaults to 12.
mlp_ratio – Ratio of mlp hidden dim to embedding dim. Defaults to 4.0.
out_chans – Output channels. Defaults to 256.
qkv_bias – If True, add a learnable bias to query, key, value. Defaults to True.
norm_layer – Normalization layer. Defaults to nn.LayerNorm.
act_layer – Activation layer. Defaults to nn.GELU.
use_abs_pos – If True, use absolute positional embeddings. Defaults to True.
use_rel_pos – If True, add relative positional embeddings to the attention map. Defaults to False.
rel_pos_zero_init – If True, zero initialize relative positional parameters. Defaults to True.
window_size – Window size for window attention blocks. Defaults to 0.
global_attn_indexes – Indexes for blocks using global attention. Defaults to ().
- img_size
- patch_embed
- blocks
- neck
- forward(x: torch.Tensor) → torch.Tensor
Forward pass through the image encoder.
- Parameters:
x – Input image tensor of shape (B, C, H, W).
- Returns:
Encoded features of shape (B, out_chans, H', W').
- Return type:
torch.Tensor
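A usage sketch built from the documented defaults (a ViT-B-like configuration). The output spatial size assumes the standard SAM relation H' = W' = img_size / patch_size; the window_size and global_attn_indexes values follow the SAM ViT-B recipe and are assumptions here, not part of this documentation:

```python
import torch

from loki2.models.utils.sam_utils import ImageEncoderViT

encoder = ImageEncoderViT(
    img_size=1024,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    out_chans=256,
    window_size=14,                  # assumed SAM ViT-B setting
    global_attn_indexes=(2, 5, 8, 11),  # assumed SAM ViT-B setting
)
image = torch.randn(1, 3, 1024, 1024)   # (B, C, H, W)
features = encoder(image)               # expected shape: (1, 256, 64, 64)
```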
- class loki2.models.utils.sam_utils.Block(dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[torch.nn.Module] = nn.LayerNorm, act_layer: Type[torch.nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Tuple[int, int] | None = None)
Bases: torch.nn.Module
Transformer block with support for window attention and residual propagation.
- Parameters:
dim – Number of input channels.
num_heads – Number of attention heads in each ViT block.
mlp_ratio – Ratio of mlp hidden dim to embedding dim. Defaults to 4.0.
qkv_bias – If True, add a learnable bias to query, key, value. Defaults to True.
norm_layer – Normalization layer. Defaults to nn.LayerNorm.
act_layer – Activation layer. Defaults to nn.GELU.
use_rel_pos – If True, add relative positional embeddings to the attention map. Defaults to False.
rel_pos_zero_init – If True, zero initialize relative positional parameters. Defaults to True.
window_size – Window size for window attention blocks. If it equals 0, then use global attention. Defaults to 0.
input_size – Input resolution for calculating the relative positional parameter size. Defaults to None.
- norm1
- attn
- norm2
- mlp
- window_size
- forward(x: torch.Tensor) → torch.Tensor
Forward pass through the transformer block.
- Parameters:
x – Input tensor.
- Returns:
Output tensor after attention and MLP processing.
- Return type:
torch.Tensor
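A sketch for a single block, assuming the SAM/ViTDet convention that tokens are arranged as a (B, H, W, C) grid; the concrete sizes are illustrative:

```python
import torch

from loki2.models.utils.sam_utils import Block

# Windowed-attention block on a 64x64 token grid; input_size supplies the
# grid shape used to size the relative positional parameters.
block = Block(
    dim=768,
    num_heads=12,
    use_rel_pos=True,
    window_size=14,
    input_size=(64, 64),
)
tokens = torch.randn(1, 64, 64, 768)
out = block(tokens)  # same shape as the input: (1, 64, 64, 768)
```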
- class loki2.models.utils.sam_utils.Attention(dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Tuple[int, int] | None = None)
Bases: torch.nn.Module
Multi-head Attention block with relative position embeddings.
- Parameters:
dim – Number of input channels.
num_heads – Number of attention heads. Defaults to 8.
qkv_bias – If True, add a learnable bias to query, key, value. Defaults to True.
use_rel_pos – If True, add relative positional embeddings to the attention map. Defaults to False.
rel_pos_zero_init – If True, zero initialize relative positional parameters. Defaults to True.
input_size – Input resolution for calculating the relative positional parameter size. Defaults to None.
- num_heads
- head_dim
- scale
- qkv
- proj
- use_rel_pos
- forward(x: torch.Tensor) → torch.Tensor
Forward pass through the attention block.
- Parameters:
x – Input tensor of shape (B, H, W, C).
- Returns:
Output tensor of shape (B, H, W, C).
- Return type:
torch.Tensor
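A sketch for the attention block alone, with illustrative sizes; when use_rel_pos is True, input_size must be given so the relative position tables can be sized:

```python
import torch

from loki2.models.utils.sam_utils import Attention

attn = Attention(dim=768, num_heads=12, use_rel_pos=True, input_size=(14, 14))
x = torch.randn(1, 14, 14, 768)  # (B, H, W, C)
out = attn(x)                    # expected shape: (1, 14, 14, 768)
```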
- loki2.models.utils.sam_utils.window_partition(x: torch.Tensor, window_size: int) → Tuple[torch.Tensor, Tuple[int, int]]
Partition input into non-overlapping windows with padding if needed.
- Parameters:
x – Input tokens with shape [B, H, W, C].
window_size – Window size for partitioning.
- Returns:
windows: Windows after partition with shape [B * num_windows, window_size, window_size, C].
(Hp, Wp): Padded height and width before partition.
- Return type:
Tuple[torch.Tensor, Tuple[int, int]]
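A sketch with illustrative sizes, assuming the documented padding behaviour (the grid is padded up to the next multiple of window_size before partitioning):

```python
import torch

from loki2.models.utils.sam_utils import window_partition

x = torch.randn(1, 64, 64, 768)                          # [B, H, W, C]
windows, (Hp, Wp) = window_partition(x, window_size=14)
# 64 pads up to 70 = 5 * 14, so Hp = Wp = 70 and windows has shape
# [1 * 5 * 5, 14, 14, 768] = [25, 14, 14, 768].
```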
- loki2.models.utils.sam_utils.window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]) → torch.Tensor
Unpartition windows into original sequences and remove padding.
- Parameters:
windows – Input tokens with shape [B * num_windows, window_size, window_size, C].
window_size – Window size.
pad_hw – Padded height and width (Hp, Wp).
hw – Original height and width (H, W) before padding.
- Returns:
Unpartitioned sequences with shape [B, H, W, C].
- Return type:
torch.Tensor
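Continuing the window_partition sketch above, unpartitioning with the recorded padded and original sizes recovers the original grid:

```python
from loki2.models.utils.sam_utils import window_unpartition

restored = window_unpartition(windows, window_size=14,
                              pad_hw=(Hp, Wp), hw=(64, 64))
# restored has shape [1, 64, 64, 768]; the padding added during
# partitioning has been cropped away.
```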
- loki2.models.utils.sam_utils.get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) → torch.Tensor
Get relative positional embeddings according to the relative positions of query and key sizes.
- Parameters:
q_size – Size of query q.
k_size – Size of key k.
rel_pos – Relative position embeddings of shape (L, C).
- Returns:
Extracted positional embeddings according to relative positions.
- Return type:
torch.Tensor
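A sketch with illustrative sizes, assuming the usual SAM table layout where rel_pos holds one embedding per possible offset (2 * max(q_size, k_size) - 1 entries):

```python
import torch

from loki2.models.utils.sam_utils import get_rel_pos

rel_pos = torch.randn(2 * 14 - 1, 64)   # (L, C) = (27, 64) offset table
rel = get_rel_pos(q_size=14, k_size=14, rel_pos=rel_pos)
# One embedding per (query, key) position pair; expected shape (14, 14, 64).
```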
- loki2.models.utils.sam_utils.add_decomposed_rel_pos(attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int]) → torch.Tensor
Calculate decomposed Relative Positional Embeddings from MViTv2.
- Parameters:
attn – Attention map tensor.
q – Query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h – Relative position embeddings (Lh, C) for height axis.
rel_pos_w – Relative position embeddings (Lw, C) for width axis.
q_size – Spatial sequence size of query q with (q_h, q_w).
k_size – Spatial sequence size of key k with (k_h, k_w).
- Returns:
Attention map with added relative positional embeddings.
- Return type:
torch.Tensor
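A sketch with hypothetical sizes, following the documented shapes; attn is assumed to be the (B, q_h * q_w, k_h * k_w) attention-logit tensor:

```python
import torch

from loki2.models.utils.sam_utils import add_decomposed_rel_pos

B, h, w, C = 2, 14, 14, 64
attn = torch.zeros(B, h * w, h * w)        # attention logits
q = torch.randn(B, h * w, C)               # queries, (B, q_h * q_w, C)
rel_pos_h = torch.randn(2 * h - 1, C)      # (Lh, C) height-axis table
rel_pos_w = torch.randn(2 * w - 1, C)      # (Lw, C) width-axis table
attn = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w,
                              q_size=(h, w), k_size=(h, w))
# attn keeps its (B, q_h * q_w, k_h * k_w) shape, now with the decomposed
# relative positional terms added.
```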
- class loki2.models.utils.sam_utils.PatchEmbed(kernel_size: Tuple[int, int] = (16, 16), stride: Tuple[int, int] = (16, 16), padding: Tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768)
Bases: torch.nn.Module
Image to Patch Embedding module.
Converts input images into patch embeddings using a convolutional layer.
- Parameters:
kernel_size – Kernel size of the projection layer. Defaults to (16, 16).
stride – Stride of the projection layer. Defaults to (16, 16).
padding – Padding size of the projection layer. Defaults to (0, 0).
in_chans – Number of input image channels. Defaults to 3.
embed_dim – Patch embedding dimension. Defaults to 768.
- proj
- forward(x: torch.Tensor) → torch.Tensor
Forward pass through patch embedding.
- Parameters:
x – Input image tensor of shape (B, C, H, W).
- Returns:
Patch embeddings of shape (B, H, W, C).
- Return type:
torch.Tensor
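A usage sketch built from the documented defaults; note the channels-last output layout (B, H, W, C) stated above:

```python
import torch

from loki2.models.utils.sam_utils import PatchEmbed

embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16),
                   in_chans=3, embed_dim=768)
image = torch.randn(1, 3, 1024, 1024)  # (B, C, H, W)
patches = embed(image)                 # expected shape: (1, 64, 64, 768)
```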