loki2.models.cell_segmentation.loki2_sam
========================================

.. py:module:: loki2.models.cell_segmentation.loki2_sam

.. autoapi-nested-parse::

   CellViT-SAM (B, L, H) model.

   This module provides CellViT variants that use Segment Anything Model (SAM) image encoders as backbones.


Module Contents
---------------

.. py:class:: CellViTSAM(model_path: Union[pathlib.Path, str], num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)

   Bases: :py:obj:`loki2.models.cell_segmentation.cellvit.CellViT`

   CellViT with a SAM backbone.

   Skip connections are shared between branches, but each branch has its own decoder.

   :param model_path: Path to the pretrained SAM model.
   :param num_nuclei_classes: Number of nuclei classes (including background).
   :param num_tissue_classes: Number of tissue classes.
   :param vit_structure: SAM model type (``"SAM-B"``, ``"SAM-L"``, or ``"SAM-H"``).
   :param drop_rate: Dropout rate in the MLP layers. Defaults to 0.
   :param regression_loss: Use a regression loss for predicting vector components.
                           Adds two additional channels to the binary decoder, which are returned
                           as a separate entry in the output dict. Defaults to False.
   :raises NotImplementedError: If the SAM configuration is unknown.


   .. py:attribute:: input_channels
      :value: 3


   .. py:attribute:: mlp_ratio
      :value: 4


   .. py:attribute:: qkv_bias
      :value: True


   .. py:attribute:: num_nuclei_classes


   .. py:attribute:: model_path


   .. py:attribute:: prompt_embed_dim
      :value: 256


   .. py:attribute:: encoder


   .. py:attribute:: classifier_head


   .. py:method:: load_pretrained_encoder(model_path)

      Load the pretrained SAM encoder from the given path.

      :param model_path: Path to the SAM model.
      :type model_path: str


   .. py:method:: forward(x: torch.Tensor, retrieve_tokens: bool = False)

      Forward pass.

      :param x: Input images in BCHW layout.
      :type x: torch.Tensor
      :param retrieve_tokens: Whether to also return the ViT tokens. Defaults to False.
      :type retrieve_tokens: bool, optional

      :returns: Output for all branches:

                * ``tissue_types``: Raw tissue type predictions. Shape: (B, num_tissue_classes)
                * ``nuclei_binary_map``: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)
                * ``hv_map``: Raw horizontal-vertical (HV) map predictions. Shape: (B, 2, H, W)
                * ``nuclei_type_map``: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)
                * ``tokens`` (optional, only if ``retrieve_tokens=True``): ViT tokens
                * ``regression_map`` (optional, only if ``regression_loss=True``): Regression map for the binary prediction. Shape: (B, 2, H, W)
      :rtype: dict


   .. py:method:: init_vit_b()


   .. py:method:: init_vit_l()


   .. py:method:: init_vit_h()
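
The output contract documented for ``forward`` can be checked without loading any weights. The sketch below is illustrative only: ``expected_output_shapes`` and ``check_output_shapes`` are hypothetical helpers that are not part of ``loki2``. They assume that ``H`` and ``W`` in the documented shapes are the height and width of the input batch ``x``. The optional tokens entry is skipped because its shape is not documented here.

.. code-block:: python

   from typing import Any, Dict, Mapping, Tuple

   Shape = Tuple[int, ...]


   def expected_output_shapes(
       batch_size: int,
       height: int,
       width: int,
       num_nuclei_classes: int,
       num_tissue_classes: int,
       regression_loss: bool = False,
   ) -> Dict[str, Shape]:
       """Return the documented output shapes of CellViTSAM.forward (hypothetical helper).

       Assumption: H and W in the documented shapes equal the input image height and width.
       """
       shapes: Dict[str, Shape] = {
           "tissue_types": (batch_size, num_tissue_classes),
           "nuclei_binary_map": (batch_size, 2, height, width),
           "hv_map": (batch_size, 2, height, width),
           "nuclei_type_map": (batch_size, num_nuclei_classes, height, width),
       }
       # Only documented when the model was built with regression_loss=True.
       if regression_loss:
           shapes["regression_map"] = (batch_size, 2, height, width)
       return shapes


   def check_output_shapes(outputs: Mapping[str, Any], expected: Mapping[str, Shape]) -> None:
       """Raise if an expected key is missing or has the wrong shape (hypothetical helper).

       Works with any value exposing ``.shape`` (e.g. torch.Tensor, numpy.ndarray).
       Keys not listed in ``expected``, such as the optional tokens, are ignored.
       """
       for key, shape in expected.items():
           if key not in outputs:
               raise KeyError(f"missing output {key!r}")
           actual = tuple(outputs[key].shape)
           if actual != shape:
               raise ValueError(f"{key}: expected {shape}, got {actual}")

With a real model, one might pass ``model(x)`` as ``outputs`` and build ``expected`` from ``x.shape`` and the class counts used at construction, setting ``regression_loss=True`` only if the model was created with that flag.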