loki2.inference.postprocessing_cupy
Postprocessing of Loki2 network output, tailored for the inference pipeline.
This module provides GPU-accelerated postprocessing using CuPy for efficient cell detection and segmentation from network predictions.
Module Contents
- class loki2.inference.postprocessing_cupy.DetectionCellPostProcessorCupy(wsi: loki2.data.dataclass.wsi.WSI | loki2.data.dataclass.wsi.WSIMetadata, nr_types: int, resolution: float = 0.25, classifier: torch.nn.Module = None, binary: bool = False, gt: bool = False)
- wsi
- nr_types
- resolution
- classifier
- gt
- binary
- check_network_output(predictions_: Dict[str, torch.Tensor]) → None
Check if the network output is valid.
- Parameters:
predictions – Network predictions dictionary with the required keys:
* nuclei_binary_map: Binary nucleus predictions. Shape: (B, H, W, 2)
* nuclei_type_map: Type prediction of nuclei. Shape: (B, H, W, self.nr_types)
* hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, H, W, 2)
- Raises:
AssertionError – If predictions dictionary is invalid or missing required keys.
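The key and shape contract above can be expressed as a few assertions. The following is a minimal sketch, not Loki2's implementation: `validate_predictions` and `REQUIRED_KEYS` are hypothetical names, and the check only covers the keys and shapes listed in the docstring. It relies solely on `.shape`, so it applies equally to torch, NumPy, or CuPy arrays.

```python
from typing import Any, Dict

import numpy as np

# Hypothetical constant: the required keys listed in the docstring.
REQUIRED_KEYS = ("nuclei_binary_map", "nuclei_type_map", "hv_map")


def validate_predictions(predictions: Dict[str, Any], nr_types: int) -> None:
    """Raise AssertionError if `predictions` violates the documented layout.

    Works with any array-like that exposes `.shape` (torch, NumPy, CuPy).
    """
    missing = [key for key in REQUIRED_KEYS if key not in predictions]
    assert not missing, f"missing required keys: {missing}"

    binary_shape = tuple(predictions["nuclei_binary_map"].shape)
    assert len(binary_shape) == 4 and binary_shape[-1] == 2, (
        "nuclei_binary_map must have shape (B, H, W, 2)"
    )
    b, h, w, _ = binary_shape
    assert tuple(predictions["nuclei_type_map"].shape) == (b, h, w, nr_types), (
        "nuclei_type_map must have shape (B, H, W, nr_types)"
    )
    assert tuple(predictions["hv_map"].shape) == (b, h, w, 2), (
        "hv_map must have shape (B, H, W, 2)"
    )


# Example: a well-formed batch of 2 images, 64x64 pixels, 6 nucleus types.
example = {
    "nuclei_binary_map": np.zeros((2, 64, 64, 2)),
    "nuclei_type_map": np.zeros((2, 64, 64, 6)),
    "hv_map": np.zeros((2, 64, 64, 2)),
}
validate_predictions(example, nr_types=6)  # passes silently
```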
- post_process_batch(predictions_: dict) → Tuple[torch.Tensor, List[dict]]
Post-process a batch of predictions and generate the instance predictions and a cell dictionary for each image in the batch.
- Parameters:
predictions (dict) – Network predictions with tokens. Required keys:
* nuclei_binary_map: Binary nucleus predictions. Shape: (B, H, W, 2)
* nuclei_type_map: Type prediction of nuclei. Shape: (B, H, W, self.nr_types)
* hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, H, W, 2)
- Returns:
torch.Tensor: Instance map in which each instance has its own integer label. Shape: (B, H, W)
List[dict]: One dictionary per image. Each dictionary contains a nested dictionary for each detected nucleus with the keys “bbox”, “centroid”, “contour”, “type_prob”, and “type”.
- Return type:
Tuple[torch.Tensor, List[dict]]
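Each per-image cell dictionary pairs an instance label with geometric information such as "bbox" and "centroid". The sketch below shows how those two fields can be derived from one image of the instance map with NumPy. It is an illustration under stated assumptions, not Loki2's code: `instance_boxes_and_centroids` is a hypothetical name, 0 is assumed to be background, the bbox is assumed to be `[[row_min, col_min], [row_max, col_max]]` with exclusive maxima, and the centroid is assumed to be in (x, y) order. The real cell dictionaries also carry "contour", "type_prob", and "type", which are omitted here.

```python
import numpy as np


def instance_boxes_and_centroids(inst_map: np.ndarray) -> dict:
    """Return {instance_id: {"bbox": (2, 2) array, "centroid": (2,) array}}.

    Assumptions (not taken from Loki2): 0 is background, the bbox is
    [[row_min, col_min], [row_max, col_max]] with exclusive maxima, and the
    centroid is the mean pixel position in (x, y) order.
    """
    cells = {}
    for inst_id in np.unique(inst_map):
        if inst_id == 0:  # skip background
            continue
        rows, cols = np.nonzero(inst_map == inst_id)
        bbox = np.array([[rows.min(), cols.min()], [rows.max() + 1, cols.max() + 1]])
        centroid = np.array([cols.mean(), rows.mean()])  # (x, y)
        cells[int(inst_id)] = {"bbox": bbox, "centroid": centroid}
    return cells


# A batch instance map of shape (B, H, W) is processed one image at a time.
batch_inst = np.zeros((1, 6, 6), dtype=np.int32)
batch_inst[0, 1:3, 1:4] = 1
batch_inst[0, 4:6, 4:6] = 2
cell_dicts = [instance_boxes_and_centroids(m) for m in batch_inst]
```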
- post_process_single_image(pred_map: cupy.ndarray) → Tuple[numpy.ndarray, dict[int, dict]]
Process a single image and generate its instance map and cell dictionary.
- Parameters:
pred_map (cp.ndarray) – Combined output of tp, np and hv branches, in the same order. Shape: (H, W, 4)
- Returns:
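Instance map of the image (shape (H, W)) and a dictionary with one entry per detected nucleus, keyed by its integer ID.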
- Return type:
Tuple[np.ndarray, dict[int, dict]]
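The `pred_map` argument stacks the tp, np, and hv branches into four channels, in that order. One plausible way to build the per-image maps from the batch dictionary is sketched below with NumPy. The reduction of the type and binary maps to class indices with `argmax` is an assumption, and `build_pred_maps` is a hypothetical name. CuPy mirrors these NumPy calls (`cupy.argmax`, `cupy.concatenate`), so the same code is expected to run on GPU arrays when the module is swapped.

```python
import numpy as np


def build_pred_maps(predictions: dict) -> np.ndarray:
    """Stack tp, np and hv into a (B, H, W, 4) array, in that order.

    Assumption: tp and np are reduced to class indices with argmax over the
    last (channel) axis; hv keeps its 2 float channels.
    """
    tp = np.argmax(predictions["nuclei_type_map"], axis=-1)[..., None]
    np_map = np.argmax(predictions["nuclei_binary_map"], axis=-1)[..., None]
    hv = predictions["hv_map"]
    return np.concatenate([tp, np_map, hv], axis=-1).astype(np.float32)


rng = np.random.default_rng(0)
preds = {
    "nuclei_type_map": rng.random((2, 16, 16, 6)),
    "nuclei_binary_map": rng.random((2, 16, 16, 2)),
    "hv_map": rng.uniform(-1, 1, (2, 16, 16, 2)),
}
pred_maps = build_pred_maps(preds)
print(pred_maps.shape)  # (2, 16, 16, 4)
# Each pred_maps[i] then has the (H, W, 4) layout expected for a single image.
```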
- class loki2.inference.postprocessing_cupy.BatchPoolingActor(detection_cell_postprocessor: DetectionCellPostProcessorCupy, run_conf: Dict[str, Any])
Ray Actor for coordinating the postprocessing of batches.
Postprocessing runs in a separate process so that it does not block the main process. The computation itself is delegated to the DetectionCellPostProcessorCupy class; this actor coordinates the postprocessing of one batch and wraps that class.
- detection_cell_postprocessor
- run_conf
- convert_batch_to_graph_nodes(predictions: Dict[str, torch.Tensor], metadata: List[Dict[str, Any]]) → Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[torch.Tensor], List[torch.Tensor]]
Postprocess a batch of predictions and convert it to graph nodes.
Returns the complete graph nodes (cell dictionary), the detection nodes (cell detection dictionary), the cell tokens and the cell positions.
- Parameters:
predictions – Network predictions dictionary with the required keys:
* nuclei_binary_map: Binary nucleus predictions. Shape: (B, H, W, 2)
* nuclei_type_map: Type prediction of nuclei. Shape: (B, H, W, self.num_nuclei_classes)
* hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, H, W, 2)
metadata – List of metadata dictionaries, one per patch. Each dictionary needs to contain:
* row: Row index of the patch
* col: Column index of the patch
Other keys are optional.
- Returns:
Complete graph nodes (cell dictionaries)
Detection nodes (cell detection dictionaries)
Cell tokens
Cell positions
- Return type:
Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[torch.Tensor], List[torch.Tensor]]
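The return types of the batch and per-patch methods match, which suggests the batch method concatenates per-patch results. The sketch below illustrates that aggregation pattern under this assumption; it is not the actor's actual code. `convert_batch`, `PatchConverter`, and `toy_convert` are hypothetical names, strings stand in for token tensors, and only the documented "row" and "col" metadata keys are checked.

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical type for a per-patch converter that returns
# (graph_nodes, detection_nodes, cell_tokens, cell_positions).
PatchConverter = Callable[[dict, dict, Any], Tuple[list, list, list, list]]


def convert_batch(
    cell_dicts: List[dict],
    metadata: List[Dict[str, Any]],
    tokens: List[Any],
    convert_patch: PatchConverter,
) -> Tuple[list, list, list, list]:
    """Apply `convert_patch` to every patch and concatenate the four outputs."""
    assert len(cell_dicts) == len(metadata) == len(tokens), "batch size mismatch"
    graph_nodes, detection_nodes, cell_tokens, cell_positions = [], [], [], []
    for cells, meta, patch_tokens in zip(cell_dicts, metadata, tokens):
        # "row" and "col" are the only metadata keys the docstring requires.
        assert "row" in meta and "col" in meta, "metadata needs 'row' and 'col'"
        nodes, detections, toks, positions = convert_patch(cells, meta, patch_tokens)
        graph_nodes.extend(nodes)
        detection_nodes.extend(detections)
        cell_tokens.extend(toks)
        cell_positions.extend(positions)
    return graph_nodes, detection_nodes, cell_tokens, cell_positions


# Usage with a toy converter; strings stand in for token tensors.
def toy_convert(cells: dict, meta: dict, patch_tokens: Any) -> Tuple[list, list, list, list]:
    ids = sorted(cells)
    return (
        [{"id": i, **meta} for i in ids],
        [{"id": i} for i in ids],
        [patch_tokens] * len(ids),
        [cells[i]["centroid"] for i in ids],
    )


batch = convert_batch(
    [{1: {"centroid": (5, 5)}}, {1: {"centroid": (2, 3)}, 2: {"centroid": (8, 1)}}],
    [{"row": 0, "col": 0}, {"row": 0, "col": 1}],
    ["tok_a", "tok_b"],
    toy_convert,
)
```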
- convert_patch_to_graph_nodes(patch_cell_dict: dict, patch_metadata: dict, patch_tokens: torch.Tensor) → Tuple[List[dict], List[dict], List[torch.Tensor], List[torch.Tensor]]
Extract information from a single patch and convert it to graph nodes for a global view.
- Parameters:
patch_cell_dict (dict) – Dictionary containing the cell information. Each cell entry needs to contain the following keys:
* bbox: Bounding box of the cell
* centroid: Centroid of the cell
* contour: Contour of the cell
* type_prob: Probability of the cell type
* type: Type of the cell
patch_metadata (dict) – Metadata dictionary for the patch. It needs to contain the following keys:
* row: Row index of the patch
* col: Column index of the patch
Other keys are optional but are stored in the graph nodes for later use.
patch_tokens (torch.Tensor) – Tokens of the patch. Shape: (D, H, W)
- Returns:
List[dict]: Complete graph nodes (cell dictionary) of the patch
List[dict]: Detection nodes (cell detection dictionary) of the patch
List[torch.Tensor]: Cell tokens of the patch
List[torch.Tensor]: Cell positions (centroid) of the patch
- Return type:
Tuple[List[dict], List[dict], List[torch.Tensor], List[torch.Tensor]]
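Converting a patch into nodes "for a global view" requires moving patch-local coordinates into a slide-level frame using the patch's row and column. The sketch below shows one way to do this. Everything beyond the row/col keys is an assumption: `to_global_nodes` is a hypothetical helper, `stride` (the pixel distance between neighbouring patch origins) is a parameter the documentation does not mention, the bbox is assumed to be `[[row_min, col_min], [row_max, col_max]]`, and the centroid is assumed to be (x, y). The example's stride of 960 is an arbitrary illustrative value. Token extraction is left out.

```python
import numpy as np


def to_global_nodes(patch_cell_dict: dict, patch_metadata: dict, stride: int):
    """Shift patch-local cell coordinates into a slide-level frame.

    Hypothetical helper. Assumptions: `stride` is the pixel distance between
    the origins of neighbouring patches, bbox is [[row_min, col_min],
    [row_max, col_max]] and centroid is (x, y).
    """
    row_off = patch_metadata["row"] * stride
    col_off = patch_metadata["col"] * stride
    nodes, positions = [], []
    for cell in patch_cell_dict.values():
        bbox = np.asarray(cell["bbox"]) + np.array([row_off, col_off])
        centroid = np.asarray(cell["centroid"]) + np.array([col_off, row_off])
        # Metadata keys are copied into every node so they remain available later.
        nodes.append({**patch_metadata, **cell, "bbox": bbox, "centroid": centroid})
        positions.append(centroid)
    return nodes, positions


cells = {1: {"bbox": [[10, 20], [30, 40]], "centroid": [30.0, 20.0], "type": 2}}
nodes, positions = to_global_nodes(cells, {"row": 1, "col": 2}, stride=960)
```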
- loki2.inference.postprocessing_cupy.get_cell_position(bbox: numpy.ndarray, patch_size: int = 1024) → List[int]
Get the cell position as a list indicating which patch borders the cell touches, in the order [top, right, down, left]. An entry is 1 if the cell touches the corresponding border.
- Parameters:
bbox – Bounding box of cell as array of shape (2, 2) in (h, w) format.
patch_size – Patch size. Defaults to 1024.
- Returns:
List with 4 integers [top, right, down, left] indicating border contact (1) or not (0).
- Return type:
List[int]
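The border-contact check can be sketched as four comparisons against the patch edges. This is an illustrative version, not necessarily how Loki2 implements it: `cell_position_sketch` is a hypothetical name, and it assumes the bbox is `[[top, left], [bottom, right]]` in (h, w) order with exclusive bottom/right coordinates, so a cell reaching the last pixel row has `bottom == patch_size`.

```python
from typing import List

import numpy as np


def cell_position_sketch(bbox: np.ndarray, patch_size: int = 1024) -> List[int]:
    """Return [top, right, down, left] border-contact flags (1 = touching).

    Assumes bbox = [[top, left], [bottom, right]] in (h, w) order, with the
    bottom/right coordinates exclusive.
    """
    top = bbox[0, 0] == 0
    left = bbox[0, 1] == 0
    down = bbox[1, 0] == patch_size
    right = bbox[1, 1] == patch_size
    return [int(top), int(right), int(down), int(left)]


print(cell_position_sketch(np.array([[0, 500], [40, 1024]])))  # [1, 1, 0, 0]
```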
- loki2.inference.postprocessing_cupy.get_cell_position_marging(bbox: numpy.ndarray, patch_size: int = 1024, margin: int = 64) → int
Get the status of the cell, describing the cell position.
A cell is either in the middle (0) or at one of the borders (1-8). Numbers are assigned clockwise, starting from the top left:
- top left = 1, top = 2, top right = 3, right = 4
- bottom right = 5, bottom = 6, bottom left = 7, left = 8
- middle = 0
- Parameters:
bbox – Bounding box of cell as array of shape (2, 2).
patch_size – Patch size. Defaults to 1024.
margin – Margin size. Defaults to 64.
- Returns:
Cell status code (0-8).
- Return type:
int
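The clockwise numbering can be reproduced by first testing which borders the bbox lies within `margin` pixels of, then resolving corners before edges. The sketch below is an assumption-laden illustration, not the library function: `cell_status_sketch` is a hypothetical name, "near a border" is taken to mean closer than `margin` pixels to it, the bbox is assumed to be `[[top, left], [bottom, right]]`, and a cell near both top and bottom is resolved in favour of the top.

```python
import numpy as np


def cell_status_sketch(bbox: np.ndarray, patch_size: int = 1024, margin: int = 64) -> int:
    """Map a (2, 2) bbox [[top, left], [bottom, right]] to a status code 0-8.

    1 top left, 2 top, 3 top right, 4 right, 5 bottom right,
    6 bottom, 7 bottom left, 8 left, 0 middle (clockwise from top left).
    Assumption: "near a border" means closer than `margin` pixels to it.
    """
    near_top = bbox[0, 0] < margin
    near_left = bbox[0, 1] < margin
    near_bottom = bbox[1, 0] > patch_size - margin
    near_right = bbox[1, 1] > patch_size - margin
    # Corners are resolved first, then the plain edges.
    if near_top:
        return 1 if near_left else 3 if near_right else 2
    if near_bottom:
        return 7 if near_left else 5 if near_right else 6
    if near_right:
        return 4
    if near_left:
        return 8
    return 0


print(cell_status_sketch(np.array([[10, 10], [50, 50]])))  # 1
```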
- loki2.inference.postprocessing_cupy.get_edge_patch(position: List[int], row: int, col: int) → List[List[int]]
Get the edge patches of a cell located at the border.
- Parameters:
position – Position of the cell encoded as a list [top, right, down, left] where 1 indicates contact with that border.
row – Row position of the patch.
col – Column position of the patch.
- Returns:
List of edge patches, each patch encoded as [row, col].
- Return type:
List[List[int]]
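A border cell may continue into neighbouring patches. Given the [top, right, down, left] flags, the candidate neighbours can be enumerated from row and column offsets. This sketch is one possible interpretation rather than Loki2's implementation: `edge_patches_sketch` is a hypothetical name, rows are assumed to grow downwards and columns to the right, and a cell touching two adjacent borders is assumed to also reach the diagonal neighbour.

```python
from itertools import product
from typing import List


def edge_patches_sketch(position: List[int], row: int, col: int) -> List[List[int]]:
    """Return [row, col] indices of neighbouring patches a border cell reaches.

    Assumes rows grow downwards and columns grow to the right, and that a
    cell touching two adjacent borders also reaches the diagonal neighbour.
    """
    top, right, down, left = position
    row_steps = [0] + ([-1] if top else []) + ([1] if down else [])
    col_steps = [0] + ([1] if right else []) + ([-1] if left else [])
    # Every combination of offsets except (0, 0), which is the patch itself.
    return [
        [row + dr, col + dc]
        for dr, dc in product(row_steps, col_steps)
        if (dr, dc) != (0, 0)
    ]


print(edge_patches_sketch([1, 1, 0, 0], row=5, col=5))  # [[5, 6], [4, 5], [4, 6]]
```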