loki2.models.cell_segmentation.backbones
Backbone networks for cell segmentation.
Different kinds of Vision Transformers for use as encoders in segmentation networks.
Module Contents
- class loki2.models.cell_segmentation.backbones.ViTCellViT(extract_layers: List[int], img_size: List[int] = [224], patch_size: int = 16, in_chans: int = 3, num_classes: int = 0, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4, qkv_bias: bool = False, qk_scale: float | None = None, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, norm_layer: Callable = nn.LayerNorm, **kwargs)
Bases:
loki2.models.base.vision_transformer.VisionTransformer
Vision Transformer with 1D positional embedding for cell segmentation.
Extends the base VisionTransformer to support extracting intermediate layer outputs for skip connections.
- extract_layers – Indices of the transformer blocks whose intermediate outputs are returned for skip connections.
- forward(x: torch.Tensor) Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Forward pass that additionally returns intermediate layer outputs for skip connections.
- Parameters:
x – Input batch tensor.
- Returns:
Output of the last layer (all tokens, without classification)
Classification output
Skip connection outputs from the extract_layers selection
- Return type:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
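A minimal usage sketch, assuming the class is importable as documented and that unpacking follows the documented return order; the extract_layers choice here is hypothetical:

```python
import torch

from loki2.models.cell_segmentation.backbones import ViTCellViT

# Hypothetical configuration: take skip connections after blocks 3, 6, 9, and 12
# (assumes 1-based block indexing for extract_layers).
backbone = ViTCellViT(
    extract_layers=[3, 6, 9, 12],
    img_size=[224],
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
)

x = torch.randn(2, 3, 224, 224)  # (B, C, H, W) batch of image crops

# Unpacking follows the documented return order: last-layer tokens,
# classification output, and skip connection outputs.
tokens, cls_out, skips = backbone(x)
```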
- class loki2.models.cell_segmentation.backbones.ViTCellViTDeit(extract_layers: List[int], img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[torch.nn.Module] = nn.LayerNorm, act_layer: Type[torch.nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, ...] = ())
Bases:
loki2.models.utils.sam_utils.ImageEncoderViT
For a parameter description see ViTCellViT.
- extract_layers – Indices of the transformer blocks whose intermediate outputs are extracted for skip connections.
- forward(x: torch.Tensor) torch.Tensor
Forward pass through the image encoder.
- Parameters:
x – Input image tensor of shape (B, C, H, W).
- Returns:
Encoded features of shape (B, out_chans, H', W').
- Return type:
torch.Tensor
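A minimal usage sketch, assuming the class is importable as documented; the extract_layers, window_size, and global_attn_indexes values below are illustrative, not prescribed defaults:

```python
import torch

from loki2.models.cell_segmentation.backbones import ViTCellViTDeit

# Hypothetical configuration: windowed attention with a few global-attention
# blocks, in the style of SAM image encoders.
encoder = ViTCellViTDeit(
    extract_layers=[3, 6, 9, 12],
    img_size=1024,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    out_chans=256,
    window_size=14,
    global_attn_indexes=(2, 5, 8, 11),
)

x = torch.randn(1, 3, 1024, 1024)  # (B, C, H, W)

# Per the docs, forward returns a feature map of shape (B, out_chans, H', W');
# with patch_size=16, H' = W' = 1024 // 16 = 64.
features = encoder(x)
```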