loki2.models.utils.sam_utils

SAM (Segment Anything Model) utilities for Loki2.

This module provides utilities adapted from Meta’s Segment Anything Model (SAM) for use as backbones in cell segmentation networks.

Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.

This source code is licensed under the license found in the LICENSE file in the root directory of this source tree.

This file has not been modified by the author of this repository.

@ Fabian Hörst, fabian.hoerst@uk-essen.de, Institute for Artificial Intelligence in Medicine, University Medicine Essen

Module Contents

class loki2.models.utils.sam_utils.MLPBlock(embedding_dim: int, mlp_dim: int, act: Type[torch.nn.Module] = nn.GELU)

Bases: torch.nn.Module

Multi-layer perceptron block: two linear layers with an activation between them.

Parameters:
  • embedding_dim – Input and output embedding dimension.

  • mlp_dim – Hidden dimension of the MLP.

  • act – Activation function class. Defaults to nn.GELU.

lin1
lin2
act
forward(x: torch.Tensor) → torch.Tensor
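The attributes lin1, lin2 and act suggest the usual expand/activate/project MLP. A minimal sketch, assuming the forward pass is lin2(act(lin1(x))) as in upstream SAM; MLPBlockSketch is a hypothetical name, not part of this module:

```python
import torch
import torch.nn as nn
from typing import Type


class MLPBlockSketch(nn.Module):
    """Hypothetical re-implementation mirroring the documented MLPBlock."""

    def __init__(self, embedding_dim: int, mlp_dim: int, act: Type[nn.Module] = nn.GELU) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)  # expand to the hidden width
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)  # project back to the input width
        self.act = act()  # `act` is a class, so it is instantiated here

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.Linear acts on the last dimension, so any leading shape is preserved.
        return self.lin2(self.act(self.lin1(x)))


mlp = MLPBlockSketch(embedding_dim=768, mlp_dim=3072)
tokens = torch.randn(2, 8, 8, 768)  # (B, H, W, C) tokens, the layout used inside the ViT blocks
out = mlp(tokens)
```

Because the activation is passed as a class rather than an instance, each block gets its own activation module.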
class loki2.models.utils.sam_utils.LayerNorm2d(num_channels: int, eps: float = 1e-06)

Bases: torch.nn.Module

2D layer normalization for image tensors of shape (B, C, H, W): each spatial position is normalized across the channel dimension.

Parameters:
  • num_channels – Number of channels in the input tensor.

  • eps – Small epsilon value for numerical stability. Defaults to 1e-6.

weight
bias
eps
forward(x: torch.Tensor) → torch.Tensor
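Normalizing across channels at each pixel makes LayerNorm2d equivalent to a channels-last nn.LayerNorm. A minimal sketch, assuming the upstream SAM implementation (mean and biased variance over dim=1, per-channel affine parameters); LayerNorm2dSketch is a hypothetical name:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm2dSketch(nn.Module):
    """Hypothetical re-implementation: LayerNorm over the channel axis of (B, C, H, W)."""

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics are computed per pixel, across channels only (dim=1).
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)  # biased variance, as in nn.LayerNorm
        x = (x - u) / torch.sqrt(s + self.eps)
        # Broadcast the per-channel affine parameters over H and W.
        return self.weight[:, None, None] * x + self.bias[:, None, None]


x = torch.randn(2, 256, 16, 16)
norm = LayerNorm2dSketch(256)
y = norm(x)
# Equivalent formulation: move channels last, apply the standard layer_norm, move them back.
ref = F.layer_norm(x.permute(0, 2, 3, 1), (256,), norm.weight, norm.bias, norm.eps).permute(0, 3, 1, 2)
```

Keeping the tensor channels-first avoids the two permutes, which suits the convolutional neck of ImageEncoderViT.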
class loki2.models.utils.sam_utils.ImageEncoderViT(img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4.0, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[torch.nn.Module] = nn.LayerNorm, act_layer: Type[torch.nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = ())

Bases: torch.nn.Module

Vision Transformer image encoder adapted from SAM/ViTDet.

This class and its supporting functions are lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py

Parameters:
  • img_size – Input image size. Defaults to 1024.

  • patch_size – Patch size. Defaults to 16.

  • in_chans – Number of input image channels. Defaults to 3.

  • embed_dim – Patch embedding dimension. Defaults to 768.

  • depth – Depth of ViT. Defaults to 12.

  • num_heads – Number of attention heads in each ViT block. Defaults to 12.

  • mlp_ratio – Ratio of mlp hidden dim to embedding dim. Defaults to 4.0.

  • out_chans – Output channels. Defaults to 256.

  • qkv_bias – If True, add a learnable bias to query, key, value. Defaults to True.

  • norm_layer – Normalization layer. Defaults to nn.LayerNorm.

  • act_layer – Activation layer. Defaults to nn.GELU.

  • use_abs_pos – If True, use absolute positional embeddings. Defaults to True.

  • use_rel_pos – If True, add relative positional embeddings to the attention map. Defaults to False.

  • rel_pos_zero_init – If True, zero initialize relative positional parameters. Defaults to True.

  • window_size – Window size for window attention blocks. If it equals 0, every block uses global attention. Defaults to 0.

  • global_attn_indexes – Indexes of the blocks that use global attention instead of window attention. Defaults to ().
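window_size and global_attn_indexes combine per block: in the upstream SAM encoder, each Block receives window_size unless its index appears in global_attn_indexes, in which case it receives 0 (global attention). A sketch of that rule, assuming this module keeps the upstream behavior; block_window_sizes is a hypothetical helper and the configuration shown is only an example:

```python
from typing import List, Tuple


def block_window_sizes(depth: int, window_size: int, global_attn_indexes: Tuple[int, ...]) -> List[int]:
    """Hypothetical helper: the window size each Block receives (0 means global attention)."""
    return [0 if i in global_attn_indexes else window_size for i in range(depth)]


# Example configuration: 12 blocks, 14x14 windows, global attention in every third block.
sizes = block_window_sizes(depth=12, window_size=14, global_attn_indexes=(2, 5, 8, 11))
print(sizes)  # [14, 14, 0, 14, 14, 0, 14, 14, 0, 14, 14, 0]
```

With the defaults (window_size=0, global_attn_indexes=()) every entry is 0, so all blocks attend globally.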

img_size
patch_embed
pos_embed: torch.nn.Parameter | None = None
blocks
neck
forward(x: torch.Tensor) → torch.Tensor

Forward pass through the image encoder.

Parameters:

x – Input image tensor of shape (B, C, H, W).

Returns:

Encoded features of shape (B, out_chans, H’, W’), where H’ = H / patch_size and W’ = W / patch_size.

Return type:

torch.Tensor
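Assuming the upstream SAM forward pass (patch embedding, addition of pos_embed without interpolation when use_abs_pos is True, the transformer blocks, then a convolutional neck that keeps the spatial size), the output shape follows from patch_size and out_chans alone. A sketch with the hypothetical helper encoder_output_shape:

```python
from typing import Tuple


def encoder_output_shape(
    batch: int,
    height: int,
    width: int,
    img_size: int = 1024,
    patch_size: int = 16,
    out_chans: int = 256,
    use_abs_pos: bool = True,
) -> Tuple[int, int, int, int]:
    """Hypothetical helper: expected ImageEncoderViT output shape for a (B, C, H, W) input."""
    if use_abs_pos and (height, width) != (img_size, img_size):
        # pos_embed is allocated for an (img_size // patch_size)^2 grid and added as-is.
        raise ValueError(f"expected a {img_size}x{img_size} input, got {height}x{width}")
    # A conv with kernel = stride = patch_size yields one token per patch;
    # the neck (1x1 conv, 3x3 conv with padding 1) keeps that grid size.
    return batch, out_chans, height // patch_size, width // patch_size


print(encoder_output_shape(2, 1024, 1024))  # (2, 256, 64, 64)
```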

class loki2.models.utils.sam_utils.Block(dim: int, num_heads: int, mlp_ratio: float = 4.0, qkv_bias: bool = True, norm_layer: Type[torch.nn.Module] = nn.LayerNorm, act_layer: Type[torch.nn.Module] = nn.GELU, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, input_size: Tuple[int, int] | None = None)

Bases: torch.nn.Module

Transformer block with support for window attention and residual propagation.

Parameters:
  • dim – Number of input channels.

  • num_heads – Number of attention heads in each ViT block.

  • mlp_ratio – Ratio of mlp hidden dim to embedding dim. Defaults to 4.0.

  • qkv_bias – If True, add a learnable bias to query, key, value. Defaults to True.

  • norm_layer – Normalization layer. Defaults to nn.LayerNorm.

  • act_layer – Activation layer. Defaults to nn.GELU.

  • use_rel_pos – If True, add relative positional embeddings to the attention map. Defaults to False.

  • rel_pos_zero_init – If True, zero initialize relative positional parameters. Defaults to True.

  • window_size – Window size for window attention blocks. If it equals 0, then use global attention. Defaults to 0.

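  • input_size – Input resolution (H, W) used to size the relative positional parameters. Only global-attention blocks use it; window-attention blocks use (window_size, window_size) instead. Defaults to None.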

norm1
attn
norm2
mlp
window_size
forward(x: torch.Tensor) → torch.Tensor

Forward pass through the transformer block.

Parameters:

x – Input tensor of shape (B, H, W, C).

Returns:

Output tensor of the same shape (B, H, W, C), after the attention and MLP residual branches.

Return type:

torch.Tensor
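The residual structure of the block can be sketched as follows. This is a simplified illustration, not this module's code: nn.MultiheadAttention is swapped in for the module's Attention (so relative positional embeddings are omitted), an nn.Sequential replaces MLPBlock, and BlockSketch, partition and unpartition are hypothetical names. Only the control flow (pre-norm, optional padded window partition, attention, unpartition, two residual additions) is meant to follow upstream SAM:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def partition(x, ws):
    # Zero-pad H and W up to multiples of ws, then cut into (ws, ws) windows.
    B, H, W, C = x.shape
    ph, pw = (ws - H % ws) % ws, (ws - W % ws) % ws
    x = F.pad(x, (0, 0, 0, pw, 0, ph))
    Hp, Wp = H + ph, W + pw
    x = x.reshape(B, Hp // ws, ws, Wp // ws, ws, C).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(-1, ws, ws, C), (Hp, Wp)


def unpartition(win, ws, pad_hw, hw):
    # Reassemble windows into (B, Hp, Wp, C) and crop the padding.
    Hp, Wp = pad_hw
    B = win.shape[0] // ((Hp // ws) * (Wp // ws))
    x = win.reshape(B, Hp // ws, Wp // ws, ws, ws, -1).permute(0, 1, 3, 2, 4, 5)
    return x.reshape(B, Hp, Wp, -1)[:, : hw[0], : hw[1], :]


class BlockSketch(nn.Module):
    """Hypothetical pre-norm block; nn.MultiheadAttention stands in for Attention."""

    def __init__(self, dim, num_heads, mlp_ratio=4.0, window_size=0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
        self.window_size = window_size

    def _attend(self, x):
        # Flatten each (h, w) grid into a sequence, attend, and restore the grid.
        B, h, w, C = x.shape
        seq = x.reshape(B, h * w, C)
        out, _ = self.attn(seq, seq, seq, need_weights=False)
        return out.reshape(B, h, w, C)

    def forward(self, x):  # x: (B, H, W, C)
        shortcut = x
        x = self.norm1(x)
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = partition(x, self.window_size)
            x = unpartition(self._attend(x), self.window_size, pad_hw, (H, W))
        else:
            x = self._attend(x)  # window_size == 0: global attention over all tokens
        x = shortcut + x  # first residual: attention branch
        return x + self.mlp(self.norm2(x))  # second residual: MLP branch


block = BlockSketch(dim=32, num_heads=4, window_size=7)
x = torch.randn(2, 10, 10, 32)  # 10 is not a multiple of 7, so padding is exercised
y = block(x)
```

Note that, as in the upstream design, zero-padded tokens take part in window attention; they are only removed afterwards by the crop in unpartition.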

class loki2.models.utils.sam_utils.Attention(dim: int, num_heads: int = 8, qkv_bias: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, input_size: Tuple[int, int] | None = None)

Bases: torch.nn.Module

Multi-head attention block with optional relative position embeddings.

Parameters:
  • dim – Number of input channels.

  • num_heads – Number of attention heads. Defaults to 8.

  • qkv_bias – If True, add a learnable bias to query, key, value. Defaults to True.

  • use_rel_pos – If True, add relative positional embeddings to the attention map. Defaults to False.

  • rel_pos_zero_init – If True, zero initialize relative positional parameters. Defaults to True.

  • input_size – Input resolution (H, W) used to size the relative positional parameters. Must be provided when use_rel_pos is True. Defaults to None.

num_heads
head_dim
scale
qkv
proj
use_rel_pos
forward(x: torch.Tensor) → torch.Tensor

Forward pass through the attention block.

Parameters:

x – Input tensor of shape (B, H, W, C).

Returns:

Output tensor of shape (B, H, W, C).

Return type:

torch.Tensor
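The forward pass is scaled dot-product attention over the H * W flattened tokens. A minimal sketch, assuming the upstream SAM qkv layout (heads folded into the batch axis) and omitting the relative positional term; AttentionSketch is a hypothetical name:

```python
import torch
import torch.nn as nn


class AttentionSketch(nn.Module):
    """Hypothetical multi-head self-attention over a (B, H, W, C) token grid, without rel-pos."""

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = True) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5  # 1 / sqrt(head_dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # (B, H*W, 3, heads, head_dim) -> (3, B, heads, H*W, head_dim)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # Fold heads into the batch axis: q, k, v are each (B*heads, H*W, head_dim).
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B*heads, H*W, H*W)
        # With use_rel_pos=True, upstream SAM adds add_decomposed_rel_pos(...) here.
        attn = attn.softmax(dim=-1)
        # Unfold heads and concatenate them back into the channel axis.
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        return self.proj(x)


attn_layer = AttentionSketch(dim=32, num_heads=4)
x = torch.randn(2, 4, 4, 32)
y = attn_layer(x)
```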

loki2.models.utils.sam_utils.window_partition(x: torch.Tensor, window_size: int) → Tuple[torch.Tensor, Tuple[int, int]]

Partition input into non-overlapping windows with padding if needed.

Parameters:
  • x – Input tokens with shape [B, H, W, C].

  • window_size – Window size for partitioning.

Returns:

  • windows: Windows after partition with shape [B * num_windows, window_size, window_size, C].

  • (Hp, Wp): Padded height and width before partition.

Return type:

Tuple[torch.Tensor, Tuple[int, int]]
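The padding arithmetic and window ordering can be sketched as below, assuming the upstream SAM behavior (zero padding on the bottom and right, windows ordered row-major within each image, images kept contiguous in the batch axis); window_partition_sketch is a hypothetical name:

```python
import torch
import torch.nn.functional as F
from typing import Tuple


def window_partition_sketch(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Hypothetical re-implementation of window_partition for (B, H, W, C) input."""
    B, H, W, C = x.shape
    # Bottom/right zero padding needed to reach a multiple of window_size.
    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        # F.pad pads from the last dim backwards: (C_left, C_right, W_left, W_right, H_top, H_bottom).
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w
    x = x.reshape(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    # Bring the two window-grid axes together, then merge them into the batch axis.
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)
    return windows, (Hp, Wp)


x = torch.randn(2, 10, 10, 8)
windows, (Hp, Wp) = window_partition_sketch(x, window_size=4)
print(windows.shape, (Hp, Wp))  # torch.Size([18, 4, 4, 8]) (12, 12)
```

Here each 10x10 image is padded to 12x12, giving 3 x 3 = 9 windows per image and 18 in total.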

loki2.models.utils.sam_utils.window_unpartition(windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]) → torch.Tensor

Unpartition windows into original sequences and remove padding.

Parameters:
  • windows – Input tokens with shape [B * num_windows, window_size, window_size, C].

  • window_size – Window size.

  • pad_hw – Padded height and width (Hp, Wp).

  • hw – Original height and width (H, W) before padding.

Returns:

Unpartitioned sequences with shape [B, H, W, C].

Return type:

torch.Tensor
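Unpartitioning reverses the reshape and permutation, then crops the padding. A minimal sketch, assuming the upstream SAM window order (row-major within each image); window_unpartition_sketch is a hypothetical name:

```python
import torch
from typing import Tuple


def window_unpartition_sketch(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
    """Hypothetical inverse of window_partition: windows -> (B, H, W, C)."""
    Hp, Wp = pad_hw
    H, W = hw
    num_windows = (Hp // window_size) * (Wp // window_size)  # windows per image
    B = windows.shape[0] // num_windows
    x = windows.reshape(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    # Undo the partition permutation so rows of windows become rows of pixels.
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, Hp, Wp, -1)
    # Drop the bottom/right padding added during partitioning.
    return x[:, :H, :W, :].contiguous()


# 2 images, each padded to 12x12 and cut into 3 x 3 = 9 windows of size 4.
windows = torch.randn(18, 4, 4, 8)
x = window_unpartition_sketch(windows, 4, pad_hw=(12, 12), hw=(10, 10))
```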

loki2.models.utils.sam_utils.get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) → torch.Tensor

Get relative positional embeddings for every pair of query and key positions. If the length of rel_pos differs from 2 * max(q_size, k_size) - 1, it is first resampled by linear interpolation.

Parameters:
  • q_size – Size of query q.

  • k_size – Size of key k.

  • rel_pos – Relative position embeddings of shape (L, C).

Returns:

Extracted positional embeddings of shape (q_size, k_size, C), one embedding per query/key position pair.

Return type:

torch.Tensor
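The lookup maps each query/key offset to a row of rel_pos. A sketch assuming the upstream SAM/MViTv2 logic; get_rel_pos_sketch is a hypothetical name. A one-channel table whose values equal their own row index makes the gathered offsets visible:

```python
import torch
import torch.nn.functional as F


def get_rel_pos_sketch(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation: gather a (q_size, k_size, C) table from rel_pos (L, C)."""
    max_rel_dist = int(2 * max(q_size, k_size) - 1)  # number of distinct offsets
    if rel_pos.shape[0] != max_rel_dist:
        # Linearly resample the table along L when it was built for another resolution.
        resized = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),  # (1, C, L)
            size=max_rel_dist,
            mode="linear",
        )
        rel_pos = resized.reshape(-1, max_rel_dist).permute(1, 0)  # (max_rel_dist, C)
    # Scale coordinates so that query and key grids of different sizes line up.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    # Shift the offsets (q - k) so the most negative one maps to row 0.
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
    return rel_pos[relative_coords.long()]


rel_pos = torch.arange(7, dtype=torch.float32)[:, None]  # L = 2 * 4 - 1 = 7, C = 1
table = get_rel_pos_sketch(4, 4, rel_pos)  # table[i, j, 0] == i - j + 3
```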

loki2.models.utils.sam_utils.add_decomposed_rel_pos(attn: torch.Tensor, q: torch.Tensor, rel_pos_h: torch.Tensor, rel_pos_w: torch.Tensor, q_size: Tuple[int, int], k_size: Tuple[int, int]) → torch.Tensor

Calculate decomposed Relative Positional Embeddings from MViTv2.

Reference: https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

Parameters:
  • attn – Attention map tensor of shape (B, q_h * q_w, k_h * k_w).

  • q – Query q in the attention layer with shape (B, q_h * q_w, C).

  • rel_pos_h – Relative position embeddings (Lh, C) for height axis.

  • rel_pos_w – Relative position embeddings (Lw, C) for width axis.

  • q_size – Spatial sequence size of query q with (q_h, q_w).

  • k_size – Spatial sequence size of key k with (k_h, k_w).

Returns:

Attention map with added relative positional embeddings.

Return type:

torch.Tensor
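The decomposition adds a height term and a width term instead of one full 2D table: for query (i, j) and key (m, n), the bias is q·Rh[i, m] + q·Rw[j, n]. A sketch assuming upstream SAM semantics. To stay self-contained it takes the already-gathered tables Rh (q_h, k_h, C) and Rw (q_w, k_w, C), i.e. what get_rel_pos returns for each axis, instead of rel_pos_h and rel_pos_w, so add_decomposed_rel_pos_sketch has a hypothetical signature that differs from the documented one:

```python
import torch
from typing import Tuple


def add_decomposed_rel_pos_sketch(
    attn: torch.Tensor,
    q: torch.Tensor,
    Rh: torch.Tensor,
    Rw: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """Hypothetical variant taking pre-gathered tables Rh (q_h, k_h, C) and Rw (q_w, k_w, C)."""
    q_h, q_w = q_size
    k_h, k_w = k_size
    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    # Height term depends only on (query row, key row); width term only on (query col, key col).
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)  # (B, q_h, q_w, k_h)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)  # (B, q_h, q_w, k_w)
    # Broadcast each term over the other key axis and add to the 5D view of attn.
    attn = attn.reshape(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    return attn.reshape(B, q_h * q_w, k_h * k_w)


B, C = 2, 4
attn = torch.randn(B, 9, 9)
q = torch.randn(B, 9, C)
Rh, Rw = torch.randn(3, 3, C), torch.randn(3, 3, C)
out = add_decomposed_rel_pos_sketch(attn, q, Rh, Rw, (3, 3), (3, 3))
```

Storing two (size, size, C) tables per axis instead of one table over all 2D offsets is what keeps the parameter count linear in the grid side length.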

class loki2.models.utils.sam_utils.PatchEmbed(kernel_size: Tuple[int, int] = (16, 16), stride: Tuple[int, int] = (16, 16), padding: Tuple[int, int] = (0, 0), in_chans: int = 3, embed_dim: int = 768)

Bases: torch.nn.Module

Image to Patch Embedding module.

Converts input images into patch embeddings using a convolutional layer.

Parameters:
  • kernel_size – Kernel size of the projection layer. Defaults to (16, 16).

  • stride – Stride of the projection layer. Defaults to (16, 16).

  • padding – Padding size of the projection layer. Defaults to (0, 0).

  • in_chans – Number of input image channels. Defaults to 3.

  • embed_dim – Patch embedding dimension. Defaults to 768.

proj
forward(x: torch.Tensor) → torch.Tensor

Forward pass through patch embedding.

Parameters:

x – Input image tensor of shape (B, C, H, W).

Returns:

Patch embeddings in channels-last layout, of shape (B, H’, W’, embed_dim), where H’ and W’ are the height and width of the patch grid.

Return type:

torch.Tensor
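Patch embedding is a single strided convolution followed by a move to channels-last. A minimal sketch, assuming the upstream SAM forward (proj, then permute to (B, H’, W’, embed_dim)); PatchEmbedSketch is a hypothetical name:

```python
import torch
import torch.nn as nn


class PatchEmbedSketch(nn.Module):
    """Hypothetical re-implementation: strided conv projection, then channels-last."""

    def __init__(self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x)  # (B, embed_dim, H / 16, W / 16) with the default kernel and stride
        return x.permute(0, 2, 3, 1)  # (B, H', W', embed_dim), the layout the ViT blocks expect


embed = PatchEmbedSketch(embed_dim=64)
patches = embed(torch.randn(2, 3, 128, 128))
print(patches.shape)  # torch.Size([2, 8, 8, 64])
```

With kernel size equal to stride and no padding, each output position sees exactly one non-overlapping patch.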