loki2.models.cell_segmentation.cellvit
CellViT networks and adaptations for cell segmentation.
UNETR paper and code: https://github.com/tamasino52/UNETR
SAM paper and code: https://segment-anything.com/
Module Contents
- class loki2.models.cell_segmentation.cellvit.CellViT(num_nuclei_classes: int, num_tissue_classes: int, embed_dim: int, input_channels: int, depth: int, num_heads: int, extract_layers: List, mlp_ratio: float = 4, qkv_bias: bool = True, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, regression_loss: bool = False)
Bases: torch.nn.Module
CellViT model for cell segmentation.
U-Net-like network with a vision transformer as backbone encoder. Skip connections are shared between the branches, but each branch has its own decoder.
- The model has multiple branches:
tissue_types: Tissue prediction based on the global class token
nuclei_binary_map: Binary nuclei prediction (nucleus vs. background)
hv_map: Horizontal-vertical (HV) map prediction, used to separate touching instances
nuclei_type_map: Nuclei type prediction
[Optional, if regression loss]: regression_map: Regression map for the binary prediction
- Parameters:
num_nuclei_classes – Number of nuclei classes (including background).
num_tissue_classes – Number of tissue classes.
embed_dim – Embedding dimension of backbone ViT.
input_channels – Number of input channels.
depth – Depth of the backbone ViT.
num_heads – Number of heads of the backbone ViT.
extract_layers – Indices of the transformer blocks whose outputs are returned in addition to the final tokens, used for skip connections. Indexing starts at 1; the maximum is N=depth. At least 4 skip connections need to be returned.
mlp_ratio – MLP ratio for hidden MLP dimension of backbone ViT. Defaults to 4.
qkv_bias – If bias should be used for query (q), key (k), and value (v) in backbone ViT. Defaults to True.
drop_rate – Dropout in MLP. Defaults to 0.
attn_drop_rate – Dropout for attention layer in backbone ViT. Defaults to 0.
drop_path_rate – Drop-path (stochastic depth) rate for residual connections. Defaults to 0.
regression_loss – Use a regression loss for predicting vector components. Adds two additional output channels to the binary decoder, returned as a separate entry in the output dict. Defaults to False.
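For orientation, a minimal construction sketch is shown below. The hyperparameters are illustrative (a ViT-Base-like configuration), not the package's prescribed defaults:

```python
import torch
from loki2.models.cell_segmentation.cellvit import CellViT

# Illustrative hyperparameters: a ViT-Base-like backbone (an assumption,
# not the package's canonical configuration).
model = CellViT(
    num_nuclei_classes=6,           # e.g. 5 nuclei types + background
    num_tissue_classes=19,
    embed_dim=768,                  # ViT-Base embedding dimension
    input_channels=3,               # RGB input
    depth=12,
    num_heads=12,
    extract_layers=[3, 6, 9, 12],   # 1-indexed block indices; at least 4 required
)
```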
- patch_size = 16
- num_tissue_classes
- num_nuclei_classes
- embed_dim
- input_channels
- depth
- num_heads
- mlp_ratio
- qkv_bias
- extract_layers
- drop_rate
- attn_drop_rate
- drop_path_rate
- encoder
- decoder0
- decoder1
- decoder2
- decoder3
- regression_loss
- offset_branches = 0
- branches_output
- nuclei_binary_map_decoder
- hv_map_decoder
- nuclei_type_maps_decoder
- forward(x: torch.Tensor, retrieve_tokens: bool = False) dict
Forward pass
- Parameters:
x (torch.Tensor) – Input images in BCHW format
retrieve_tokens (bool, optional) – Whether the ViT tokens should be returned as well. Defaults to False.
- Returns:
- Output for all branches:
tissue_types: Raw tissue type predictions. Shape: (B, num_tissue_classes)
nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)
hv_map: Raw HV map predictions. Shape: (B, 2, H, W)
nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)
[Optional, if retrieve_tokens]: tokens
[Optional, if regression loss]:
regression_map: Regression map for binary prediction. Shape: (B, 2, H, W)
- Return type:
dict
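A minimal forward-pass sketch, continuing the construction example above (shapes assume 256×256 RGB inputs and the illustrative hyperparameters used there):

```python
import torch

x = torch.randn(2, 3, 256, 256)  # (B, C, H, W)
model.eval()
with torch.no_grad():
    out = model(x, retrieve_tokens=True)

out["tissue_types"].shape       # torch.Size([2, 19])
out["nuclei_binary_map"].shape  # torch.Size([2, 2, 256, 256])
out["hv_map"].shape             # torch.Size([2, 2, 256, 256])
out["nuclei_type_map"].shape    # torch.Size([2, 6, 256, 256])
# out["tokens"] is present because retrieve_tokens=True
```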
- create_upsampling_branch(num_classes: int) torch.nn.Module
Create an upsampling (decoder) branch
- Parameters:
num_classes (int) – Number of output classes
- Returns:
Upsampling path
- Return type:
nn.Module
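As a hypothetical illustration (this extra head is not part of the documented model), an additional decoder branch could be built the same way the internal nuclei and HV branches are:

```python
# Hypothetical: a fourth upsampling branch with 4 output classes.
extra_branch = model.create_upsampling_branch(num_classes=4)
```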
- calculate_instance_map(predictions: collections.OrderedDict, magnification: Literal[20, 40] = 40) Tuple[torch.Tensor, List[dict]]
Calculate the instance map from network predictions (after softmax has been applied). A combined usage sketch follows generate_instance_nuclei_map below.
- Parameters:
predictions (dict) – Dictionary with the following required keys:
nuclei_binary_map: Binary nucleus predictions. Shape: (B, 2, H, W)
nuclei_type_map: Type predictions of nuclei. Shape: (B, self.num_nuclei_classes, H, W)
hv_map: Horizontal-vertical nuclei mapping. Shape: (B, 2, H, W)
magnification (Literal[20, 40], optional) – Which magnification the data has. Defaults to 40.
- Returns:
torch.Tensor: Instance map in which each instance has its own integer. Shape: (B, H, W)
- List[dict]: One entry per image; each entry is a dictionary with one sub-dictionary per detected nucleus.
For each nucleus, the following information is returned: “bbox”, “centroid”, “contour”, “type_prob”, “type”
- Return type:
Tuple[torch.Tensor, List[dict]]
- generate_instance_nuclei_map(instance_maps: torch.Tensor, type_preds: List[dict]) torch.Tensor
Convert a binary instance map to a nuclei-type instance map
- Parameters:
instance_maps (torch.Tensor) – Binary instance map in which each instance has its own integer. Shape: (B, H, W)
type_preds (List[dict]) – List (len=B) of dictionaries with instance type information (see the post_process_hovernet function for details)
- Returns:
Nuclei type instance map. Shape: (B, self.num_nuclei_classes, H, W)
- Return type:
torch.Tensor
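The combined post-processing sketch below continues the forward-pass example above. Per the docstrings, softmax is applied to the classification branches before calling calculate_instance_map, while the HV branch stays raw; 40× magnification is assumed:

```python
from collections import OrderedDict

import torch.nn.functional as F

# Softmax the classification branches; the HV branch is a regression output.
preds = OrderedDict(
    nuclei_binary_map=F.softmax(out["nuclei_binary_map"], dim=1),
    nuclei_type_map=F.softmax(out["nuclei_type_map"], dim=1),
    hv_map=out["hv_map"],
)

# (B, H, W) integer instance map plus a per-image list of per-nucleus dicts.
instance_map, instance_types = model.calculate_instance_map(preds, magnification=40)

# Expand the flat instance map into one channel per nuclei type.
nuclei_type_instances = model.generate_instance_nuclei_map(instance_map, instance_types)
```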
- freeze_encoder()
Freeze the encoder so it is not trained
- unfreeze_encoder()
Unfreeze the encoder so the whole model can be trained
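These two hooks suggest a two-stage fine-tuning recipe. The sketch below is one possible schedule (learning rates and staging are assumptions, not a prescribed procedure):

```python
import torch

# Stage 1: train only the decoder branches on top of a frozen encoder.
model.freeze_encoder()
optimizer = torch.optim.AdamW(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4
)
# ... run decoder warm-up epochs ...

# Stage 2: unfreeze and fine-tune end to end at a lower learning rate.
model.unfreeze_encoder()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
```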
- class loki2.models.cell_segmentation.cellvit.DataclassHVStorage
Stores UniSeg prediction/GT objects for calculating losses, metrics, etc. with HoverNet-style networks
- Parameters:
nuclei_binary_map (torch.Tensor) – Softmax output for binary nuclei branch. Shape: (batch_size, 2, H, W)
hv_map (torch.Tensor) – Logit output for HV-Map. Shape: (batch_size, 2, H, W)
nuclei_type_map (torch.Tensor) – Softmax output for nuclei type prediction. Shape: (batch_size, num_nuclei_classes, H, W)
tissue_types (torch.Tensor) – Logit tissue prediction output. Shape: (batch_size, num_tissue_classes)
instance_map (torch.Tensor) – Pixel-wise nuclear instance segmentation. Each instance has its own integer, starting from 1. Shape: (batch_size, H, W)
instance_types_nuclei (torch.Tensor) – Pixel-wise nuclear instance segmentation predictions, one channel per nuclei type. Each instance has its own integer, starting from 1. Shape: (batch_size, num_nuclei_classes, H, W)
batch_size (int) – Batch size of the experiment
instance_types (list, optional) – Instance type prediction list with one entry per image. Each entry is a dictionary keyed by the nuclei instance number (int); each value is another dict with the keys bbox (bounding box), centroid (centroid coordinates), contour, type_prob (type probability), and type (nuclei type). Defaults to None.
regression_map (torch.Tensor, optional) – Regression map for binary prediction map. Shape: (batch_size, 2, H, W). Defaults to None.
regression_loss (bool, optional) – Indicating if regression map is present. Defaults to False.
h (int, optional) – Height of used input images. Defaults to 256.
w (int, optional) – Width of used input images. Defaults to 256.
num_tissue_classes (int, optional) – Number of tissue classes in the data. Defaults to 19.
num_nuclei_classes (int, optional) – Number of nuclei types in the data (including background). Defaults to 6.
- nuclei_binary_map: torch.Tensor
- hv_map: torch.Tensor
- tissue_types: torch.Tensor
- nuclei_type_map: torch.Tensor
- instance_map: torch.Tensor
- instance_types_nuclei: torch.Tensor
- batch_size: int
- instance_types: list = None
- regression_map: torch.Tensor = None
- regression_loss: bool = False
- h: int = 256
- w: int = 256
- num_tissue_classes: int = 19
- num_nuclei_classes: int = 6
- get_dict() dict
Return the dataclass entries as a dictionary
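A minimal packaging sketch, continuing the examples above (keyword arguments follow the documented fields; the import path is assumed from this module):

```python
from loki2.models.cell_segmentation.cellvit import DataclassHVStorage

storage = DataclassHVStorage(
    nuclei_binary_map=preds["nuclei_binary_map"],
    hv_map=out["hv_map"],
    nuclei_type_map=preds["nuclei_type_map"],
    tissue_types=out["tissue_types"],
    instance_map=instance_map,
    instance_types_nuclei=nuclei_type_instances,
    batch_size=2,
    instance_types=instance_types,
)

loss_inputs = storage.get_dict()  # plain dict keyed by field name
```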