loki2.models.cell_segmentation.loki2_sam
CellViT-SAM (B, L, H) model.
This module provides CellViT variants using Segment Anything Model (SAM) backbones.
Module Contents
- class loki2.models.cell_segmentation.loki2_sam.CellViTSAM(model_path: pathlib.Path | str, num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)
Bases: loki2.models.cell_segmentation.cellvit.CellViT
CellViT with SAM backbone settings.
Skip connections are shared between branches, but each network has a distinct encoder.
- Parameters:
model_path – Path to pretrained SAM model.
num_nuclei_classes – Number of nuclei classes (including background).
num_tissue_classes – Number of tissue classes.
vit_structure – SAM model type (“SAM-B”, “SAM-L”, or “SAM-H”).
drop_rate – Dropout in MLP. Defaults to 0.
regression_loss – Use regression loss for predicting vector components. Adds two additional channels to the binary decoder, but returns them as a separate entry in the output dict. Defaults to False.
- Raises:
NotImplementedError – If SAM configuration is unknown.
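The vit_structure argument selects one of three SAM backbones, and an unknown value raises NotImplementedError. As a minimal sketch of that kind of dispatch (not the loki2 source): the table below lists the image-encoder hyperparameters published with the original Segment Anything release, and the names SAM_ENCODER_CONFIGS and resolve_sam_config are hypothetical. Whether CellViTSAM uses exactly these values internally is an assumption.

```python
from typing import Dict, List, Union

# Image-encoder hyperparameters from the original Segment Anything release.
# ASSUMPTION: loki2 may configure its encoders differently; this table is
# only an illustration of what the three structures usually mean.
SAM_ENCODER_CONFIGS: Dict[str, Dict[str, Union[int, List[int]]]] = {
    "SAM-B": {"embed_dim": 768, "depth": 12, "num_heads": 12,
              "global_attn_indexes": [2, 5, 8, 11]},
    "SAM-L": {"embed_dim": 1024, "depth": 24, "num_heads": 16,
              "global_attn_indexes": [5, 11, 17, 23]},
    "SAM-H": {"embed_dim": 1280, "depth": 32, "num_heads": 16,
              "global_attn_indexes": [7, 15, 23, 31]},
}


def resolve_sam_config(vit_structure: str) -> Dict[str, Union[int, List[int]]]:
    """Hypothetical helper: look up a backbone config or fail loudly."""
    try:
        return SAM_ENCODER_CONFIGS[vit_structure]
    except KeyError:
        # Mirrors the documented behaviour for unknown SAM configurations.
        raise NotImplementedError(
            f"Unknown SAM backbone structure: {vit_structure!r}"
        ) from None


if __name__ == "__main__":
    print(resolve_sam_config("SAM-B")["embed_dim"])
```

Validating the string before constructing the model gives a clear error early, rather than a failure while loading checkpoint weights.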
- input_channels = 3
- mlp_ratio = 4
- qkv_bias = True
- num_nuclei_classes
- model_path
- prompt_embed_dim = 256
- encoder
- classifier_head
- load_pretrained_encoder(model_path)
Load the pretrained SAM encoder from the provided path.
- Parameters:
model_path (str) – Path to SAM model
- forward(x: torch.Tensor, retrieve_tokens: bool = False)
Forward pass
- Parameters:
x (torch.Tensor) – Images in BCHW layout (batch, channels, height, width).
retrieve_tokens (bool, optional) – If tokens of ViT should be returned as well. Defaults to False.
- Returns:
- Output for all branches:
tissue_types: Raw tissue type prediction. Shape: (B, num_tissue_classes)
nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)
hv_map: Horizontal-vertical (HV) map predictions. Shape: (B, 2, H, W)
nuclei_type_map: Raw binary nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)
[Optional, if retrieve_tokens]: tokens
[Optional, if regression_loss]:
regression_map: Regression map for binary prediction. Shape: (B, 2, H, W)
- Return type:
dict
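The documented shapes of the forward() output can be written down as a small checker. This is a sketch under stated assumptions: expected_output_shapes and find_shape_mismatches are hypothetical helpers, not part of loki2, and the class counts and image size in the usage lines are illustrative values only. The shape of the optional tokens entry is not documented above, so it is left out.

```python
from typing import Any, Dict, List, Tuple


def expected_output_shapes(
    batch_size: int,
    height: int,
    width: int,
    num_nuclei_classes: int,
    num_tissue_classes: int,
    regression_loss: bool = False,
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper: shapes as listed in the forward() docs."""
    shapes: Dict[str, Tuple[int, ...]] = {
        "tissue_types": (batch_size, num_tissue_classes),
        "nuclei_binary_map": (batch_size, 2, height, width),
        "hv_map": (batch_size, 2, height, width),
        "nuclei_type_map": (batch_size, num_nuclei_classes, height, width),
    }
    if regression_loss:
        # Only present when the model was built with regression_loss=True.
        shapes["regression_map"] = (batch_size, 2, height, width)
    return shapes


def find_shape_mismatches(
    outputs: Dict[str, Any], expected: Dict[str, Tuple[int, ...]]
) -> List[str]:
    """Compare any dict of objects exposing .shape against the expectation."""
    problems = []
    for key, shape in expected.items():
        if key not in outputs:
            problems.append(f"missing key: {key}")
        elif tuple(outputs[key].shape) != shape:
            problems.append(
                f"{key}: got {tuple(outputs[key].shape)}, expected {shape}"
            )
    return problems


if __name__ == "__main__":
    # Illustrative values: batch of 2 images at 256x256, 6 nuclei classes
    # (including background) and 19 tissue classes.
    for name, shape in expected_output_shapes(2, 256, 256, 6, 19).items():
        print(name, shape)
```

In practice, find_shape_mismatches could be called on the dict returned by forward(), since torch.Tensor exposes a .shape attribute that converts to a tuple.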
- init_vit_b()
- init_vit_l()
- init_vit_h()