loki2.models.cell_segmentation.cellvit
======================================

.. py:module:: loki2.models.cell_segmentation.cellvit

.. autoapi-nested-parse::

   CellViT networks and adaptations for cell segmentation.

   UNETR paper and code: https://github.com/tamasino52/UNETR
   SAM paper and code: https://segment-anything.com/

Module Contents
---------------

.. py:class:: CellViT(num_nuclei_classes: int, num_tissue_classes: int, embed_dim: int, input_channels: int, depth: int, num_heads: int, extract_layers: List, mlp_ratio: float = 4, qkv_bias: bool = True, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, regression_loss: bool = False)

   Bases: :py:obj:`torch.nn.Module`

   CellViT model for cell segmentation: a U-Net-like network with a vision transformer as backbone encoder. Skip connections from the encoder are shared between branches, but each branch has a distinct decoder.

   The model has multiple branches:

   * tissue_types: Tissue prediction based on the global class token
   * nuclei_binary_map: Binary nuclei prediction
   * hv_map: Horizontal-vertical (HV) prediction used to separate touching nuclei into isolated instances
   * nuclei_type_map: Nuclei instance prediction
   * [Optional, if regression loss]: regression_map: Regression map for the binary prediction

   :param num_nuclei_classes: Number of nuclei classes (including background).
   :param num_tissue_classes: Number of tissue classes.
   :param embed_dim: Embedding dimension of the backbone ViT.
   :param input_channels: Number of input channels.
   :param depth: Depth of the backbone ViT.
   :param num_heads: Number of attention heads of the backbone ViT.
   :param extract_layers: List of transformer blocks whose outputs should be returned in addition to the tokens. The first block starts at 1; the maximum is N = depth. Used for skip connections; at least 4 skip connections need to be returned.
   :param mlp_ratio: MLP ratio for the hidden MLP dimension of the backbone ViT. Defaults to 4.
   :param qkv_bias: Whether a bias should be used for the query (q), key (k), and value (v) projections in the backbone ViT. Defaults to True.
   :param drop_rate: Dropout rate in the MLP. Defaults to 0.
   :param attn_drop_rate: Dropout rate for the attention layers in the backbone ViT. Defaults to 0.
   :param drop_path_rate: Drop-path rate for skip connections. Defaults to 0.
   :param regression_loss: Use a regression loss for predicting vector components. Adds two additional channels to the binary decoder, returned as a separate entry in the output dict. Defaults to False.

   .. py:attribute:: patch_size
      :value: 16

   .. py:attribute:: num_tissue_classes

   .. py:attribute:: num_nuclei_classes

   .. py:attribute:: embed_dim

   .. py:attribute:: input_channels

   .. py:attribute:: depth

   .. py:attribute:: num_heads

   .. py:attribute:: mlp_ratio

   .. py:attribute:: qkv_bias

   .. py:attribute:: extract_layers

   .. py:attribute:: drop_rate

   .. py:attribute:: attn_drop_rate

   .. py:attribute:: drop_path_rate

   .. py:attribute:: encoder

   .. py:attribute:: decoder0

   .. py:attribute:: decoder1

   .. py:attribute:: decoder2

   .. py:attribute:: decoder3

   .. py:attribute:: regression_loss

   .. py:attribute:: offset_branches
      :value: 0

   .. py:attribute:: branches_output

   .. py:attribute:: nuclei_binary_map_decoder

   .. py:attribute:: hv_map_decoder

   .. py:attribute:: nuclei_type_maps_decoder

   .. py:method:: forward(x: torch.Tensor, retrieve_tokens: bool = False) -> dict

      Forward pass.

      :param x: Images in BCHW format.
      :type x: torch.Tensor
      :param retrieve_tokens: Whether the ViT tokens should be returned as well. Defaults to False.
      :type retrieve_tokens: bool, optional
      :returns: Output for all branches:

                * tissue_types: Raw tissue type predictions. Shape: (B, num_tissue_classes)
                * nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)
                * hv_map: Raw HV map predictions. Shape: (B, 2, H, W)
                * nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)
                * [Optional, if retrieve_tokens]: tokens
                * [Optional, if regression loss]: regression_map: Regression map for the binary prediction. Shape: (B, 2, H, W)
      :rtype: dict

   .. py:method:: create_upsampling_branch(num_classes: int) -> torch.nn.Module

      Create an upsampling branch.

      :param num_classes: Number of output classes.
      :type num_classes: int
      :returns: Upsampling path.
      :rtype: nn.Module

   .. py:method:: calculate_instance_map(predictions: collections.OrderedDict, magnification: Literal[20, 40] = 40) -> Tuple[torch.Tensor, List[dict]]

      Calculate instance maps from the network predictions (after softmax).

      :param predictions: Dictionary with the following required keys:

                          * nuclei_binary_map: Binary nucleus predictions. Shape: (B, 2, H, W)
                          * nuclei_type_map: Type predictions of the nuclei. Shape: (B, self.num_nuclei_classes, H, W)
                          * hv_map: Horizontal-vertical nuclei mapping. Shape: (B, 2, H, W)
      :type predictions: dict
      :param magnification: Magnification of the data. Defaults to 40.
      :type magnification: Literal[20, 40], optional
      :returns: * torch.Tensor: Instance map; each instance has its own integer. Shape: (B, H, W)
                * List of dictionaries, one entry per image. Each entry contains a dict for every detected nucleus with the keys "bbox", "centroid", "contour", "type_prob", and "type".
      :rtype: Tuple[torch.Tensor, List[dict]]

   .. py:method:: generate_instance_nuclei_map(instance_maps: torch.Tensor, type_preds: List[dict]) -> torch.Tensor

      Convert a binary instance map to a nuclei type instance map.

      :param instance_maps: Binary instance map; each instance has its own integer. Shape: (B, H, W)
      :type instance_maps: torch.Tensor
      :param type_preds: List (len = B) of dictionaries with instance type information (see the post_process_hovernet function for details).
      :type type_preds: List[dict]
      :returns: Nuclei type instance map. Shape: (B, self.num_nuclei_classes, H, W)
      :rtype: torch.Tensor

   .. py:method:: freeze_encoder()

      Freeze the encoder so it is excluded from training.

   .. py:method:: unfreeze_encoder()

      Unfreeze the encoder to train the whole model.
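The branch layout documented for ``forward()`` can be restated as a small shape-bookkeeping sketch. This is plain Python and not part of loki2; the helper name and its arguments are illustrative, and it simply re-derives the documented output shapes for a given batch and image size:

```python
# Illustrative only: expected tensor shapes of the forward() output dict.
# This helper is NOT part of loki2; it restates the documented shapes.

def forward_output_shapes(batch_size, height, width,
                          num_tissue_classes, num_nuclei_classes,
                          regression_loss=False):
    shapes = {
        # global class token -> tissue classifier
        "tissue_types": (batch_size, num_tissue_classes),
        # two channels: background vs. nucleus
        "nuclei_binary_map": (batch_size, 2, height, width),
        # horizontal and vertical distance maps
        "hv_map": (batch_size, 2, height, width),
        # one channel per nuclei class, including background
        "nuclei_type_map": (batch_size, num_nuclei_classes, height, width),
    }
    if regression_loss:
        # extra entry when the model was built with regression_loss=True
        shapes["regression_map"] = (batch_size, 2, height, width)
    return shapes

# Example: batch of 8 images at 256x256 with the defaults used elsewhere
# in this module (19 tissue classes, 6 nuclei classes).
shapes = forward_output_shapes(8, 256, 256,
                               num_tissue_classes=19,
                               num_nuclei_classes=6)
```

Checking a returned dict against such a table before post-processing (e.g. before ``calculate_instance_map``) is a cheap way to catch configuration mismatches.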
.. py:class:: DataclassHVStorage

   Stores UniSeg prediction/ground-truth objects for calculating losses, metrics, etc. with HoverNet networks.

   :param nuclei_binary_map: Softmax output of the binary nuclei branch. Shape: (batch_size, 2, H, W)
   :type nuclei_binary_map: torch.Tensor
   :param hv_map: Logit output for the HV map. Shape: (batch_size, 2, H, W)
   :type hv_map: torch.Tensor
   :param nuclei_type_map: Softmax output of the nuclei type prediction. Shape: (batch_size, num_nuclei_classes, H, W)
   :type nuclei_type_map: torch.Tensor
   :param tissue_types: Logit tissue prediction output. Shape: (batch_size, num_tissue_classes)
   :type tissue_types: torch.Tensor
   :param instance_map: Pixel-wise nuclear instance segmentation. Each instance has its own integer, starting from 1. Shape: (batch_size, H, W)
   :type instance_map: torch.Tensor
   :param instance_types_nuclei: Pixel-wise nuclear instance segmentation predictions, one channel per nuclei type. Each instance has its own integer, starting from 1. Shape: (batch_size, num_nuclei_classes, H, W)
   :type instance_types_nuclei: torch.Tensor
   :param batch_size: Batch size of the experiment.
   :type batch_size: int
   :param instance_types: Instance type prediction list. Each list entry stands for one image and is a dictionary keyed by the nuclei instance number (int). The value for each instance is another dict with the keys: bbox (bounding box), centroid (centroid coordinates), contour, type_prob (type probability), and type (nuclei type). Defaults to None.
   :type instance_types: list, optional
   :param regression_map: Regression map for the binary prediction map. Shape: (batch_size, 2, H, W). Defaults to None.
   :type regression_map: torch.Tensor, optional
   :param regression_loss: Indicates whether a regression map is present. Defaults to False.
   :type regression_loss: bool, optional
   :param h: Height of the input images. Defaults to 256.
   :type h: int, optional
   :param w: Width of the input images. Defaults to 256.
   :type w: int, optional
   :param num_tissue_classes: Number of tissue classes in the data. Defaults to 19.
   :type num_tissue_classes: int, optional
   :param num_nuclei_classes: Number of nuclei types in the data (including background). Defaults to 6.
   :type num_nuclei_classes: int, optional

   .. py:attribute:: nuclei_binary_map
      :type: torch.Tensor

   .. py:attribute:: hv_map
      :type: torch.Tensor

   .. py:attribute:: tissue_types
      :type: torch.Tensor

   .. py:attribute:: nuclei_type_map
      :type: torch.Tensor

   .. py:attribute:: instance_map
      :type: torch.Tensor

   .. py:attribute:: instance_types_nuclei
      :type: torch.Tensor

   .. py:attribute:: batch_size
      :type: int

   .. py:attribute:: instance_types
      :type: list
      :value: None

   .. py:attribute:: regression_map
      :type: torch.Tensor
      :value: None

   .. py:attribute:: regression_loss
      :type: bool
      :value: False

   .. py:attribute:: h
      :type: int
      :value: 256

   .. py:attribute:: w
      :type: int
      :value: 256

   .. py:attribute:: num_tissue_classes
      :type: int
      :value: 19

   .. py:attribute:: num_nuclei_classes
      :type: int
      :value: 6

   .. py:method:: get_dict() -> dict

      Return a dictionary of the dataclass entries.
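For readers unfamiliar with the ``get_dict()`` pattern, a minimal stand-in illustrates it. This sketch is not the loki2 class: it uses plain stdlib dataclasses, carries no tensors, and keeps only a reduced field set, but it shows how a dataclass can expose its fields as a plain dict for loss and metric code:

```python
from dataclasses import dataclass, asdict
from typing import Optional

# Minimal stand-in for DataclassHVStorage (NOT the loki2 class): it only
# demonstrates the get_dict() pattern of exposing dataclass fields as a
# plain dictionary. Tensor-valued fields are omitted for brevity.

@dataclass
class HVStorageSketch:
    batch_size: int
    instance_types: Optional[list] = None
    regression_loss: bool = False
    h: int = 256
    w: int = 256
    num_tissue_classes: int = 19
    num_nuclei_classes: int = 6

    def get_dict(self) -> dict:
        # Return every field as a key/value pair.
        return asdict(self)

entry = HVStorageSketch(batch_size=4)
d = entry.get_dict()
```

Downstream loss and metric functions can then consume the entries by key instead of depending on the dataclass type itself.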