loki2.models.cell_segmentation.loki2_sam

CellViT-SAM (B, L, H) model.

This module provides CellViT variants using Segment Anything Model (SAM) backbones.

Module Contents

class loki2.models.cell_segmentation.loki2_sam.CellViTSAM(model_path: pathlib.Path | str, num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)

Bases: loki2.models.cell_segmentation.cellvit.CellViT

CellViT with SAM backbone settings.

Skip connections are shared between branches, but each branch has a distinct decoder.
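The plain-Python sketch below shows that branch layout. It assumes a single encoder (consistent with the single `encoder` attribute listed below) whose intermediate features act as skip connections, and a separate decoder for each pixel-level branch. Every function and "feature" here is a hypothetical stand-in, not the loki2 API. The tissue classifier head is omitted.

```python
from typing import Callable, Dict, List


def toy_encoder(image: str) -> List[str]:
    # Hypothetical stand-in for the SAM ViT encoder: one "feature" per
    # extracted layer. These become the shared skip connections.
    return [f"{image}:z{i}" for i in range(1, 5)]


def make_decoder(name: str) -> Callable[[List[str]], str]:
    # Each branch gets its own decoder (its own parameters in a real model).
    def decode(skips: List[str]) -> str:
        return f"{name}({', '.join(skips)})"
    return decode


DECODERS: Dict[str, Callable[[List[str]], str]] = {
    "nuclei_binary_map": make_decoder("binary_decoder"),
    "hv_map": make_decoder("hv_decoder"),
    "nuclei_type_map": make_decoder("type_decoder"),
}


def toy_forward(image: str) -> Dict[str, str]:
    skips = toy_encoder(image)  # the encoder runs once per input
    # Every branch decoder consumes the same list of skip connections.
    return {key: decoder(skips) for key, decoder in DECODERS.items()}
```

For example, `toy_forward("img")["hv_map"]` evaluates to `"hv_decoder(img:z1, img:z2, img:z3, img:z4)"`.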

Parameters:
  • model_path – Path to pretrained SAM model.

  • num_nuclei_classes – Number of nuclei classes (including background).

  • num_tissue_classes – Number of tissue classes.

  • vit_structure – SAM model type (“SAM-B”, “SAM-L”, or “SAM-H”).

  • drop_rate – Dropout rate in the MLP. Defaults to 0.

  • regression_loss – Use a regression loss to predict vector components. This adds two channels to the binary decoder output. They are returned as a separate entry (regression_map) in the output dict. Defaults to False.
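When regression_loss is enabled, the binary decoder produces four channels that end up in two dict entries. The sketch below shows that split for a single image, with each channel represented as a nested list of floats. It assumes the first two channels are the binary logits and the last two are the regression components. The helper name is hypothetical.

```python
from typing import Dict, List, Sequence

Plane = List[List[float]]  # one H x W channel


def split_binary_decoder_output(
    channels: Sequence[Plane], regression_loss: bool
) -> Dict[str, List[Plane]]:
    """Split one image's binary-decoder output into output-dict entries.

    Assumption: with regression_loss=True the decoder emits 4 channels,
    ordered [binary logits (2), regression components (2)].
    """
    expected = 4 if regression_loss else 2
    if len(channels) != expected:
        raise ValueError(f"expected {expected} channels, got {len(channels)}")
    out = {"nuclei_binary_map": list(channels[:2])}
    if regression_loss:
        # The two extra channels become their own entry.
        out["regression_map"] = list(channels[2:])
    return out
```

With batched tensors, the same split would be a slice along the channel axis, for example `[:, :2]` and `[:, 2:]`.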

Raises:

NotImplementedError – If vit_structure is not a known SAM configuration (“SAM-B”, “SAM-L”, or “SAM-H”).
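A minimal sketch of how vit_structure might be validated and dispatched to one of the init_vit_* methods documented below. It assumes exact, case-sensitive matching against the three documented values. The mapping and helper names are hypothetical.

```python
from typing import Dict

# Hypothetical mapping from vit_structure to the matching initializer name.
_SAM_INITIALIZERS: Dict[str, str] = {
    "SAM-B": "init_vit_b",
    "SAM-L": "init_vit_l",
    "SAM-H": "init_vit_h",
}


def resolve_initializer(vit_structure: str) -> str:
    """Return the init method name for a SAM variant, or raise."""
    try:
        return _SAM_INITIALIZERS[vit_structure]
    except KeyError:
        raise NotImplementedError(
            f"Unknown SAM backbone structure: {vit_structure!r}"
        ) from None
```

Inside the class, the returned name could be used as `getattr(self, name)()`. A lookup table keeps the accepted values in one place.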

input_channels = 3
mlp_ratio = 4
qkv_bias = True
num_nuclei_classes
model_path
prompt_embed_dim = 256
encoder
classifier_head
load_pretrained_encoder(model_path)

Load a pretrained SAM encoder from the provided path.

Parameters:

model_path (str) – Path to SAM model
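Official SAM checkpoints store the image encoder, prompt encoder, and mask decoder together. Their keys are prefixed accordingly, for example `image_encoder.patch_embed.proj.weight`. The sketch below shows one common way to keep only the image-encoder weights and strip the prefix. This page does not say whether loki2 filters keys this way, so treat the helper as a hypothetical illustration.

```python
from typing import Dict, TypeVar

V = TypeVar("V")


def extract_image_encoder_weights(
    state_dict: Dict[str, V], prefix: str = "image_encoder."
) -> Dict[str, V]:
    """Keep only image-encoder entries and strip their prefix.

    Assumes SAM-style checkpoint keys. Entries without the prefix
    (prompt encoder, mask decoder) are dropped.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
```

With PyTorch, the filtered dict would typically be passed to `encoder.load_state_dict(filtered, strict=False)`. The `strict=False` argument tolerates keys that exist on only one side.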

forward(x: torch.Tensor, retrieve_tokens: bool = False)

Forward pass

Parameters:
  • x (torch.Tensor) – Images in BCHW layout.

  • retrieve_tokens (bool, optional) – Whether to also return the ViT tokens. Defaults to False.

Returns:

Output for all branches:
  • tissue_types: Raw tissue type prediction. Shape: (B, num_tissue_classes)

  • nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)

  • hv_map: Horizontal-vertical (HV) distance map predictions. Shape: (B, 2, H, W)

  • nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)

  • [Optional, if retrieve_tokens]: tokens

  • [Optional, if regression_loss]: regression_map: Regression map for binary prediction. Shape: (B, 2, H, W)

Return type:

dict
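The helper below encodes the shapes listed under Returns. It can serve as a reference when checking a model's output dict. The function name and example sizes are hypothetical. The tokens entry is left out because its shape is not documented here.

```python
from typing import Dict, Tuple


def expected_output_shapes(
    batch: int,
    height: int,
    width: int,
    num_nuclei_classes: int,
    num_tissue_classes: int,
    regression_loss: bool = False,
) -> Dict[str, Tuple[int, ...]]:
    """Shapes of the forward() output dict, per the Returns section."""
    shapes: Dict[str, Tuple[int, ...]] = {
        "tissue_types": (batch, num_tissue_classes),
        "nuclei_binary_map": (batch, 2, height, width),
        "hv_map": (batch, 2, height, width),
        "nuclei_type_map": (batch, num_nuclei_classes, height, width),
    }
    if regression_loss:
        shapes["regression_map"] = (batch, 2, height, width)
    return shapes
```

Given a real output dict `out`, a check could compare `tuple(out[key].shape)` against each entry, for example with `expected_output_shapes(2, 256, 256, num_nuclei_classes=6, num_tissue_classes=19)`. These class counts are illustrative only.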

init_vit_b()
init_vit_l()
init_vit_h()
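The three initializers are documented only by name. The backbone loads pretrained SAM weights, so each one presumably builds the matching SAM image-encoder geometry. The sketch below records the hyperparameters of the publicly released SAM ViT-B/L/H image encoders. It also derives the MLP hidden width from the mlp_ratio attribute above. It is an assumption that init_vit_b/l/h set exactly these values. Anything else they configure, such as which layers feed the skip connections, is not stated on this page.

```python
from typing import Dict, List, NamedTuple


class SamEncoderConfig(NamedTuple):
    embed_dim: int
    depth: int
    num_heads: int
    global_attn_indexes: List[int]  # 0-based blocks using global attention


# Image-encoder settings of the public SAM ViT-B/L/H checkpoints.
SAM_ENCODER_CONFIGS: Dict[str, SamEncoderConfig] = {
    "SAM-B": SamEncoderConfig(768, 12, 12, [2, 5, 8, 11]),
    "SAM-L": SamEncoderConfig(1024, 24, 16, [5, 11, 17, 23]),
    "SAM-H": SamEncoderConfig(1280, 32, 16, [7, 15, 23, 31]),
}

MLP_RATIO = 4  # same value as the mlp_ratio attribute listed above


def mlp_hidden_dim(vit_structure: str) -> int:
    # Width of each transformer block's MLP hidden layer.
    return SAM_ENCODER_CONFIGS[vit_structure].embed_dim * MLP_RATIO


def head_dim(vit_structure: str) -> int:
    # Per-head attention width (embed_dim split evenly across heads).
    cfg = SAM_ENCODER_CONFIGS[vit_structure]
    return cfg.embed_dim // cfg.num_heads
```

The MLP hidden widths come out to 3072, 4096, and 5120 for B, L, and H. The per-head widths are 64, 64, and 80.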