loki2.models.cell_segmentation.backbones

Backbone networks for cell segmentation.

Vision Transformer variants for use as encoders in segmentation networks.

Module Contents

class loki2.models.cell_segmentation.backbones.ViTCellViT(extract_layers: List[int], img_size: List[int] = [224], patch_size: int = 16, in_chans: int = 3, num_classes: int = 0, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4, qkv_bias: bool = False, qk_scale: float | None = None, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, norm_layer: Callable = nn.LayerNorm, **kwargs)

Bases: loki2.models.base.vision_transformer.VisionTransformer

Vision Transformer with 1D positional embedding for cell segmentation.

Extends the base VisionTransformer to support extracting intermediate layer outputs for skip connections.

extract_layers
forward(x: torch.Tensor) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Forward pass that also returns intermediate outputs for skip connections.

Parameters:

x – Input batch tensor.

Returns:

  • Output of the final layer (all tokens, without the classification output)

  • Classification output

  • Skip connection outputs from the layers selected by extract_layers

Return type:

Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
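
Example (a minimal usage sketch; the extract_layers indices, image size, and batch shape are illustrative assumptions, not values prescribed by the API):

    import torch
    from loki2.models.cell_segmentation.backbones import ViTCellViT

    # ViT-Base style backbone exposing four intermediate layers for skip connections
    backbone = ViTCellViT(
        extract_layers=[3, 6, 9, 12],  # assumed layer indices
        img_size=[256],
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
    )

    x = torch.randn(2, 3, 256, 256)  # (B, C, H, W) input batch
    tokens, cls_out, skips = backbone(x)
    # tokens: final-layer output (all tokens, without classification)
    # cls_out: classification output
    # skips: intermediate outputs for the layers listed in extract_layers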

class loki2.models.cell_segmentation.backbones.ViTCellViTDeit(extract_layers: List[int], img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[torch.nn.Module] = nn.LayerNorm, act_layer: Type[torch.nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = ())

Bases: loki2.models.utils.sam_utils.ImageEncoderViT

For a description of the shared parameters, see ViTCellViT above; the remaining parameters follow the base class loki2.models.utils.sam_utils.ImageEncoderViT.

extract_layers
forward(x: torch.Tensor) → torch.Tensor

Forward pass through the image encoder.

Parameters:

x – Input image tensor of shape (B, C, H, W).

Returns:

Encoded features of shape (B, out_chans, H', W').

Return type:

torch.Tensor
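
Example (a minimal usage sketch of this SAM-style encoder; the extract_layers, window_size, and global_attn_indexes values are illustrative assumptions, not values prescribed by the API):

    import torch
    from loki2.models.cell_segmentation.backbones import ViTCellViTDeit

    # ViT-B style image encoder with windowed attention and a few global-attention blocks
    encoder = ViTCellViTDeit(
        extract_layers=[3, 6, 9, 12],       # assumed layer indices
        img_size=1024,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        out_chans=256,
        window_size=14,                     # assumed window size
        global_attn_indexes=(2, 5, 8, 11),  # assumed global-attention blocks
    )

    x = torch.randn(1, 3, 1024, 1024)  # (B, C, H, W)
    features = encoder(x)              # (B, out_chans, H', W') encoded features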