loki2.models.cell_segmentation.cellvit

CellViT networks and adaptations for cell segmentation.

UNETR paper and code: https://github.com/tamasino52/UNETR

SAM paper and code: https://segment-anything.com/

Module Contents

class loki2.models.cell_segmentation.cellvit.CellViT(num_nuclei_classes: int, num_tissue_classes: int, embed_dim: int, input_channels: int, depth: int, num_heads: int, extract_layers: List, mlp_ratio: float = 4, qkv_bias: bool = True, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, regression_loss: bool = False)

Bases: torch.nn.Module

CellViT Model for cell segmentation.

U-Net-like network with a vision transformer as the backbone encoder. Skip connections are shared between branches, but each branch has its own decoder.

The model has multiple branches:
  • tissue_types: Tissue prediction based on global class token

  • nuclei_binary_map: Binary nuclei prediction

  • hv_map: Horizontal-vertical (HV) map prediction, used to separate touching instances

  • nuclei_type_map: Pixel-wise nuclei type prediction

  • [Optional, if regression loss]: regression_map: Regression map for binary prediction

Parameters:
  • num_nuclei_classes – Number of nuclei classes (including background).

  • num_tissue_classes – Number of tissue classes.

  • embed_dim – Embedding dimension of backbone ViT.

  • input_channels – Number of input channels.

  • depth – Depth of the backbone ViT.

  • num_heads – Number of heads of the backbone ViT.

  • extract_layers – Indices of the transformer blocks whose outputs are returned in addition to the final tokens and used as skip connections. Indices are 1-based; the maximum is N=depth. At least 4 layers must be extracted.

  • mlp_ratio – MLP ratio for hidden MLP dimension of backbone ViT. Defaults to 4.

  • qkv_bias – If bias should be used for query (q), key (k), and value (v) in backbone ViT. Defaults to True.

  • drop_rate – Dropout in MLP. Defaults to 0.

  • attn_drop_rate – Dropout for attention layer in backbone ViT. Defaults to 0.

  • drop_path_rate – Stochastic depth (drop path) rate for the residual connections of the backbone ViT. Defaults to 0.

  • regression_loss – Use a regression loss for predicting the vector components. Adds two additional channels to the binary decoder, which are returned as a separate entry in the output dict. Defaults to False.
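A minimal instantiation sketch. The concrete values below (a ViT-S-like backbone and PanNuke-style class counts) are illustrative assumptions, not defaults of this module:

    import torch
    from loki2.models.cell_segmentation.cellvit import CellViT

    # Hypothetical configuration: 12-block ViT, four extracted skip layers
    model = CellViT(
        num_nuclei_classes=6,          # assumed: 5 nuclei types + background
        num_tissue_classes=19,         # assumed tissue class count
        embed_dim=384,
        input_channels=3,
        depth=12,
        num_heads=6,
        extract_layers=[3, 6, 9, 12],  # 1-based block indices; at least 4 required
    )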

patch_size = 16
num_tissue_classes
num_nuclei_classes
embed_dim
input_channels
depth
num_heads
mlp_ratio
qkv_bias
extract_layers
drop_rate
attn_drop_rate
drop_path_rate
encoder
decoder0
decoder1
decoder2
decoder3
regression_loss
offset_branches = 0
branches_output
nuclei_binary_map_decoder
hv_map_decoder
nuclei_type_maps_decoder
forward(x: torch.Tensor, retrieve_tokens: bool = False) dict

Forward pass

Parameters:
  • x (torch.Tensor) – Input images in BCHW format

  • retrieve_tokens (bool, optional) – Whether the ViT tokens should be returned as well. Defaults to False.

Returns:

Output for all branches:
  • tissue_types: Raw tissue type prediction. Shape: (B, num_tissue_classes)

  • nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)

  • hv_map: Raw HV map predictions. Shape: (B, 2, H, W)

  • nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)

  • [Optional, if retrieve_tokens]: tokens: Backbone ViT tokens

  • [Optional, if regression loss]: regression_map: Regression map for binary prediction. Shape: (B, 2, H, W)

Return type:

dict
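Continuing the sketch above, a forward pass on a batch of 256×256 RGB tiles (H and W must be divisible by patch_size = 16; the commented shapes follow the assumed configuration):

    x = torch.rand(2, 3, 256, 256)  # BCHW input
    out = model(x)
    print(out["tissue_types"].shape)       # torch.Size([2, 19])
    print(out["nuclei_binary_map"].shape)  # torch.Size([2, 2, 256, 256])
    print(out["hv_map"].shape)             # torch.Size([2, 2, 256, 256])
    print(out["nuclei_type_map"].shape)    # torch.Size([2, 6, 256, 256])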

create_upsampling_branch(num_classes: int) torch.nn.Module

Create an upsampling branch.

Parameters:

num_classes (int) – Number of output classes

Returns:

Upsampling path

Return type:

nn.Module
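This is the helper that builds the per-branch decoders listed above. A one-line sketch, assuming the model from the earlier example:

    # e.g. a two-channel (background/foreground) decoder branch
    extra_branch = model.create_upsampling_branch(num_classes=2)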

calculate_instance_map(predictions: collections.OrderedDict, magnification: Literal[20, 40] = 40) Tuple[torch.Tensor, List[dict]]

Calculate the instance map from network predictions (after softmax has been applied).

Parameters:
  • predictions (dict) – Dictionary with the following required keys:
      • nuclei_binary_map: Binary nucleus predictions. Shape: (B, 2, H, W)
      • nuclei_type_map: Type prediction of nuclei. Shape: (B, self.num_nuclei_classes, H, W)
      • hv_map: Horizontal-vertical nuclei mapping. Shape: (B, 2, H, W)

  • magnification (Literal[20, 40], optional) – Which magnification the data has. Defaults to 40.

Returns:

  • torch.Tensor: Instance map. Each instance has its own integer. Shape: (B, H, W)

  • List of dictionaries, one entry per image. Each entry maps a detected nucleus to a dict with the following keys: “bbox”, “centroid”, “contour”, “type_prob”, “type”

Return type:

Tuple[torch.Tensor, List[dict]]
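A post-processing sketch that applies the required softmax before calling the method, continuing the forward-pass example above:

    import torch.nn.functional as F
    from collections import OrderedDict

    predictions = OrderedDict(
        nuclei_binary_map=F.softmax(out["nuclei_binary_map"], dim=1),
        nuclei_type_map=F.softmax(out["nuclei_type_map"], dim=1),
        hv_map=out["hv_map"],  # HV maps are passed through as-is
    )
    instance_map, instance_types = model.calculate_instance_map(
        predictions, magnification=40
    )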

generate_instance_nuclei_map(instance_maps: torch.Tensor, type_preds: List[dict]) torch.Tensor

Convert a binary instance map to a nuclei type instance map.

Parameters:
  • instance_maps (torch.Tensor) – Binary instance map; each instance has its own integer. Shape: (B, H, W)

  • type_preds (List[dict]) – List (len=B) of dictionaries with instance type information (see the post_process_hovernet function for details)

Returns:

Nuclei type instance map. Shape: (B, self.num_nuclei_classes, H, W)

Return type:

torch.Tensor
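Chaining the outputs of calculate_instance_map from the sketch above:

    nuclei_type_instances = model.generate_instance_nuclei_map(
        instance_map, instance_types
    )
    # Shape: (B, num_nuclei_classes, H, W)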

freeze_encoder()

Freeze the encoder so that it is not trained.

unfreeze_encoder()

Unfreeze the encoder to train the whole model.
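A common two-stage fine-tuning pattern these helpers enable; a sketch, not a schedule prescribed by this module:

    model.freeze_encoder()    # stage 1: train only the decoder branches
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=1e-4
    )
    # ... run warm-up epochs ...

    model.unfreeze_encoder()  # stage 2: fine-tune the whole model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)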

class loki2.models.cell_segmentation.cellvit.DataclassHVStorage

Stores UniSeg prediction/GT objects for calculating losses, metrics, etc. with HoverNet-style networks.

Parameters:
  • nuclei_binary_map (torch.Tensor) – Softmax output for binary nuclei branch. Shape: (batch_size, 2, H, W)

  • hv_map (torch.Tensor) – Logit output for HV-Map. Shape: (batch_size, 2, H, W)

  • nuclei_type_map (torch.Tensor) – Softmax output for nuclei type prediction. Shape: (batch_size, num_nuclei_classes, H, W)

  • tissue_types (torch.Tensor) – Logit tissue prediction output. Shape: (batch_size, num_tissue_classes)

  • instance_map (torch.Tensor) – Pixel-wise nuclear instance segmentation. Each instance has its own integer, starting from 1. Shape: (batch_size, H, W)

  • instance_types_nuclei – Pixel-wise nuclear instance segmentation predictions for each nucleus type. Each instance has its own integer, starting from 1. Shape: (batch_size, num_nuclei_classes, H, W)

  • batch_size (int) – Batch size of the experiment

  • instance_types (list, optional) – Instance type prediction list, one entry per image. Each entry is a dictionary keyed by nuclei instance number (int), whose value is a dict with the keys: bbox (bounding box), centroid (centroid coordinates), contour, type_prob (type probability), and type (nuclei type). Defaults to None.

  • regression_map (torch.Tensor, optional) – Regression map for binary prediction map. Shape: (batch_size, 2, H, W). Defaults to None.

  • regression_loss (bool, optional) – Whether a regression map is present. Defaults to False.

  • h (int, optional) – Height of the input images. Defaults to 256.

  • w (int, optional) – Width of the input images. Defaults to 256.

  • num_tissue_classes (int, optional) – Number of tissue classes in the data. Defaults to 19.

  • num_nuclei_classes (int, optional) – Number of nuclei types in the data (including background). Defaults to 6.

nuclei_binary_map: torch.Tensor
hv_map: torch.Tensor
tissue_types: torch.Tensor
nuclei_type_map: torch.Tensor
instance_map: torch.Tensor
instance_types_nuclei: torch.Tensor
batch_size: int
instance_types: list = None
regression_map: torch.Tensor = None
regression_loss: bool = False
h: int = 256
w: int = 256
num_tissue_classes: int = 19
num_nuclei_classes: int = 6
get_dict() dict

Return the stored entries as a dictionary.
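A sketch that bundles the earlier predictions into the dataclass for loss and metric computation; all values come from the previous examples:

    storage = DataclassHVStorage(
        nuclei_binary_map=F.softmax(out["nuclei_binary_map"], dim=1),
        hv_map=out["hv_map"],
        tissue_types=out["tissue_types"],
        nuclei_type_map=F.softmax(out["nuclei_type_map"], dim=1),
        instance_map=instance_map,
        instance_types_nuclei=nuclei_type_instances,
        batch_size=2,                    # matches the forward-pass batch
        instance_types=instance_types,
    )
    loss_inputs = storage.get_dict()     # plain dict of all stored entries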