loki2.models.cell_segmentation.backbones
=========================================

.. py:module:: loki2.models.cell_segmentation.backbones

.. autoapi-nested-parse::

   Backbone networks for cell segmentation.

   Different kinds of Vision Transformers for use as encoders in segmentation networks.


Module Contents
---------------

.. py:class:: ViTCellViT(extract_layers: List[int], img_size: List[int] = [224], patch_size: int = 16, in_chans: int = 3, num_classes: int = 0, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4, qkv_bias: bool = False, qk_scale: Optional[float] = None, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, norm_layer: Callable = nn.LayerNorm, **kwargs)

   Bases: :py:obj:`loki2.models.base.vision_transformer.VisionTransformer`

   Vision Transformer with 1D positional embedding for cell segmentation.

   Extends the base VisionTransformer to support extracting intermediate layer outputs for skip connections.

   .. py:attribute:: extract_layers

   .. py:method:: forward(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

      Forward pass that also returns intermediate outputs for skip connections.

      :param x: Input batch tensor.
      :returns: - Output of the last layer (all tokens, without classification)
                - Classification output
                - Skip connection outputs from the extract_layers selection
      :rtype: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]


.. py:class:: ViTCellViTDeit(extract_layers: List[int], img_size: int = 1024, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 768, depth: int = 12, num_heads: int = 12, mlp_ratio: float = 4, out_chans: int = 256, qkv_bias: bool = True, norm_layer: Type[torch.nn.Module] = nn.LayerNorm, act_layer: Type[torch.nn.Module] = nn.GELU, use_abs_pos: bool = True, use_rel_pos: bool = False, rel_pos_zero_init: bool = True, window_size: int = 0, global_attn_indexes: Tuple[int, Ellipsis] = ())

   Bases: :py:obj:`loki2.models.utils.sam_utils.ImageEncoderViT`

   For a parameter description, see ViTCellViT.

   .. py:attribute:: extract_layers

   .. py:method:: forward(x: torch.Tensor) -> torch.Tensor

      Forward pass through the image encoder.

      :param x: Input image tensor of shape (B, C, H, W).
      :returns: Encoded features of shape (B, out_chans, H', W').
      :rtype: torch.Tensor
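
Example
-------

A minimal usage sketch of ``ViTCellViT`` as a segmentation encoder. The ``extract_layers`` indices, batch size, and input resolution below are illustrative assumptions, not values prescribed by this module; the unpacking order follows the return description of ``forward`` above.

.. code-block:: python

   import torch

   from loki2.models.cell_segmentation.backbones import ViTCellViT

   # Assumed configuration: extract skip features after blocks 3, 6, 9, and 12
   # of a 12-block ViT operating on 224x224 crops.
   encoder = ViTCellViT(
       extract_layers=[3, 6, 9, 12],
       img_size=[224],
       patch_size=16,
       embed_dim=768,
       depth=12,
       num_heads=12,
   )

   x = torch.randn(2, 3, 224, 224)  # (B, C, H, W)

   # Per the documented return order: last-layer tokens, classification output,
   # and one skip-connection output per entry in extract_layers.
   tokens, cls_out, skips = encoder(x)

``ViTCellViTDeit`` follows the same ``extract_layers`` idea but, as documented above, its ``forward`` returns a single encoded feature map of shape (B, out_chans, H', W').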