loki2.inference.postprocessing_cupy

Postprocessing of Loki2 network output, tailored for the inference pipeline.

This module provides GPU-accelerated postprocessing using CuPy for efficient cell detection and segmentation from network predictions.

Module Contents

class loki2.inference.postprocessing_cupy.DetectionCellPostProcessorCupy(wsi: loki2.data.dataclass.wsi.WSI | loki2.data.dataclass.wsi.WSIMetadata, nr_types: int, resolution: float = 0.25, classifier: torch.nn.Module = None, binary: bool = False, gt: bool = False)
wsi
nr_types
resolution
classifier
gt
binary
check_network_output(predictions_: Dict[str, torch.Tensor]) None

Check if the network output is valid.

Parameters:

predictions_

Network predictions dictionary with required keys:

  • nuclei_binary_map: Binary nucleus predictions. Shape: (B, H, W, 2)

  • nuclei_type_map: Type prediction of nuclei. Shape: (B, H, W, self.nr_types)

  • hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, H, W, 2)

Raises:

AssertionError – If predictions dictionary is invalid or missing required keys.
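
The shape contract above can be checked before any GPU work starts. The following is a minimal sketch, not the module's implementation: `validate_prediction_shapes` is a hypothetical helper that works on plain shape tuples, so it runs without torch or CuPy. The extra check that all maps share the same (B, H, W) is an assumption drawn from the shapes listed above.

```python
from typing import Dict, Tuple

# Required keys and the expected size of the last (channel) dimension.
# None means "must equal nr_types".
REQUIRED_KEYS = {"nuclei_binary_map": 2, "nuclei_type_map": None, "hv_map": 2}


def validate_prediction_shapes(shapes: Dict[str, Tuple[int, ...]], nr_types: int) -> None:
    """Hypothetical helper: assert that all required maps exist with (B, H, W, C) shapes."""
    for key, channels in REQUIRED_KEYS.items():
        assert key in shapes, f"missing required key: {key}"
        shape = shapes[key]
        assert len(shape) == 4, f"{key} must have shape (B, H, W, C), got {shape}"
        expected = nr_types if channels is None else channels
        assert shape[-1] == expected, f"{key}: expected {expected} channels, got {shape[-1]}"
    # Assumption: all maps agree on batch size and spatial dimensions.
    leading = {shapes[key][:3] for key in REQUIRED_KEYS}
    assert len(leading) == 1, f"inconsistent (B, H, W) across maps: {leading}"


shapes = {
    "nuclei_binary_map": (8, 256, 256, 2),
    "nuclei_type_map": (8, 256, 256, 6),
    "hv_map": (8, 256, 256, 2),
}
validate_prediction_shapes(shapes, nr_types=6)  # passes silently
```

With real tensors, the shape dictionary could be built as `{k: tuple(v.shape) for k, v in predictions_.items()}`.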

post_process_batch(predictions_: dict) Tuple[torch.Tensor, List[dict]]

Post-process a batch of predictions and generate, for each image in the batch, the instance predictions and a cell dictionary.

Parameters:

predictions_ (dict) – Network predictions with tokens. Required keys:

  • nuclei_binary_map: Binary nucleus predictions. Shape: (B, H, W, 2)

  • nuclei_type_map: Type prediction of nuclei. Shape: (B, H, W, self.nr_types)

  • hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, H, W, 2)

Returns:

  • torch.Tensor: Instance map. Each instance has its own integer ID. Shape: (B, H, W)

  • List of dictionaries, one entry per image. Each dictionary contains a nested dictionary for each detected nucleus.

    For each nucleus, the following information is returned: “bbox”, “centroid”, “contour”, “type_prob”, “type”

Return type:

Tuple[torch.Tensor, List[dict]]
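
To illustrate how the second return value might be consumed, the sketch below tallies nucleus types per image. The sample data is made up, and the layout (one dictionary per image, keyed by an integer nucleus ID) is an assumption based on the description above; `count_types_per_image` is a hypothetical helper.

```python
from collections import Counter
from typing import Dict, List


def count_types_per_image(cell_dicts: List[Dict[int, dict]]) -> List[Counter]:
    """Count how many nuclei of each type were detected in each image."""
    return [Counter(cell["type"] for cell in cells.values()) for cells in cell_dicts]


# Made-up stand-in for the second return value of post_process_batch (two images).
example = [
    {
        1: {"bbox": [[0, 0], [10, 12]], "centroid": [5.0, 6.0], "contour": [], "type_prob": 0.9, "type": 1},
        2: {"bbox": [[20, 30], [28, 41]], "centroid": [24.0, 35.5], "contour": [], "type_prob": 0.7, "type": 2},
        3: {"bbox": [[50, 5], [61, 15]], "centroid": [55.5, 10.0], "contour": [], "type_prob": 0.8, "type": 1},
    },
    {},  # image without detections
]

counts = count_types_per_image(example)
print(counts[0][1], counts[0][2], len(counts[1]))  # 2 1 0
```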

post_process_single_image(pred_map: cupy.ndarray) Tuple[numpy.ndarray, dict[int, dict]]

Process a single image and generate its cell dictionary and instance predictions.

Parameters:

pred_map (cp.ndarray) – Combined output of tp, np and hv branches, in the same order. Shape: (H, W, 4)

Returns:

Instance map of the image and the cell dictionary, with one entry per detected nucleus.

Return type:

Tuple[np.ndarray, dict[int, dict]]

class loki2.inference.postprocessing_cupy.BatchPoolingActor(detection_cell_postprocessor: DetectionCellPostProcessorCupy, run_conf: Dict[str, Any])

Ray Actor for coordinating the postprocessing of batches.

The postprocessing runs in a separate process to avoid blocking the main process. The computation itself is delegated to the DetectionCellPostProcessorCupy class; this actor coordinates the postprocessing of one batch and wraps that class.

detection_cell_postprocessor
run_conf
convert_batch_to_graph_nodes(predictions: Dict[str, torch.Tensor], metadata: List[Dict[str, Any]]) Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[torch.Tensor], List[torch.Tensor]]

Postprocess a batch of predictions and convert it to graph nodes.

Returns the complete graph nodes (cell dictionary), the detection nodes (cell detection dictionary), the cell tokens and the cell positions.

Parameters:
  • predictions

    Network predictions dictionary with required keys:

    • nuclei_binary_map: Binary nucleus predictions. Shape: (B, H, W, 2)

    • nuclei_type_map: Type prediction of nuclei. Shape: (B, H, W, self.num_nuclei_classes)

    • hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, H, W, 2)

  • metadata – List of metadata dictionaries, one per patch. Each dictionary must contain:

    • row: Row index of the patch

    • col: Column index of the patch

    Other keys are optional.

Returns:

  • Complete graph nodes (cell dictionaries)

  • Detection nodes (cell detection dictionaries)

  • Cell tokens

  • Cell positions

Return type:

Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[torch.Tensor], List[torch.Tensor]]

convert_patch_to_graph_nodes(patch_cell_dict: dict, patch_metadata: dict, patch_tokens: torch.Tensor) Tuple[List[dict], List[dict], List[torch.Tensor], List[torch.Tensor]]

Extract information from a single patch and convert it to graph nodes for a global view

Parameters:
  • patch_cell_dict (dict) – Dictionary containing the cell information. Each cell entry must contain the following keys:

    • bbox: Bounding box of the cell

    • centroid: Centroid of the cell

    • contour: Contour of the cell

    • type_prob: Probability of the cell type

    • type: Type of the cell

  • patch_metadata (dict) – Metadata dictionary for the patch. It must contain the following keys:

    • row: Row index of the patch

    • col: Column index of the patch

    Other keys are optional but are stored in the graph nodes for later use.

  • patch_tokens (torch.Tensor) – Tokens of the patch. Shape: (D, H, W)

Returns:

  • List[dict]: Complete graph nodes (cell dictionary) of the patch

  • List[dict]: Detection nodes (cell detection dictionary) of the patch

  • List[torch.Tensor]: Cell tokens of the patch

  • List[torch.Tensor]: Cell positions (centroid) of the patch

Return type:

Tuple[List[dict], List[dict], List[torch.Tensor], List[torch.Tensor]]
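
Moving from patch-local to slide-level ("global") coordinates amounts to adding the patch origin to each cell coordinate. The sketch below assumes that patches lie on a regular grid with stride `patch_size - overlap` and that the centroid is given in (row axis, column axis) order. The tiling convention, the `overlap` parameter, the coordinate order, and the helper name `to_global` are illustrative assumptions, not taken from the module.

```python
from typing import List, Sequence


def to_global(centroid: Sequence[float], row: int, col: int,
              patch_size: int = 1024, overlap: int = 0) -> List[float]:
    """Hypothetical helper: shift a patch-local centroid to slide coordinates."""
    stride = patch_size - overlap  # assumed distance between neighbouring patch origins
    return [centroid[0] + row * stride, centroid[1] + col * stride]


print(to_global([5.0, 6.0], row=2, col=3))              # [2053.0, 3078.0]
print(to_global([5.0, 6.0], row=2, col=3, overlap=64))  # [1925.0, 2886.0]
```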

loki2.inference.postprocessing_cupy.get_cell_position(bbox: numpy.ndarray, patch_size: int = 1024) List[int]

Get cell position as a list indicating which borders the cell touches.

An entry is 1 if the cell touches the corresponding border, in the order [top, right, down, left].

Parameters:
  • bbox – Bounding box of cell as array of shape (2, 2) in (h, w) format.

  • patch_size – Patch size. Defaults to 1024.

Returns:

List of 4 integers [top, right, down, left] indicating border contact (1) or no contact (0).

Return type:

List[int]
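
The encoding can be illustrated with a small stand-alone sketch. It is not the module's code: `border_contacts` is a hypothetical name, the bounding box is passed as nested tuples instead of a NumPy array to keep the example dependency-free, and "touching" is assumed to mean that a box coordinate equals 0 or `patch_size`.

```python
from typing import List, Sequence


def border_contacts(bbox: Sequence[Sequence[int]], patch_size: int = 1024) -> List[int]:
    """Sketch: flag which patch borders a bounding box touches as [top, right, down, left].

    bbox is ((h_min, w_min), (h_max, w_max)), i.e. (h, w) format.
    """
    (h_min, w_min), (h_max, w_max) = bbox
    top = h_min == 0
    right = w_max == patch_size
    down = h_max == patch_size
    left = w_min == 0
    return [int(top), int(right), int(down), int(left)]


print(border_contacts(((0, 100), (20, 130))))        # [1, 0, 0, 0]
print(border_contacts(((1000, 990), (1024, 1024))))  # [0, 1, 1, 0]
print(border_contacts(((300, 300), (320, 330))))     # [0, 0, 0, 0]
```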

loki2.inference.postprocessing_cupy.get_cell_position_marging(bbox: numpy.ndarray, patch_size: int = 1024, margin: int = 64) int

Get the status of the cell, describing the cell position.

A cell is either in the middle of the patch (0) or at one of the borders (1-8). Numbers are assigned clockwise, starting from the top left:

  • top left = 1, top = 2, top right = 3, right = 4

  • bottom right = 5, bottom = 6, bottom left = 7, left = 8

  • middle = 0

Parameters:
  • bbox – Bounding box of cell as array of shape (2, 2).

  • patch_size – Patch size. Defaults to 1024.

  • margin – Margin size. Defaults to 64.

Returns:

Cell status code (0-8).

Return type:

int
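
A sketch of how the status codes could be derived from a bounding box is shown below. It follows the clockwise numbering described above, but the exact comparisons (strict inequalities against `margin` and `patch_size - margin`) and the order in which zones are tested are assumptions; `margin_status` is a hypothetical name, and the bounding box is given as nested tuples in (h, w) format.

```python
from typing import Sequence


def margin_status(bbox: Sequence[Sequence[int]], patch_size: int = 1024, margin: int = 64) -> int:
    """Sketch: return 0 for a cell in the middle, 1-8 for border zones (clockwise from top left)."""
    (h_min, w_min), (h_max, w_max) = bbox
    near_top = h_min < margin
    near_left = w_min < margin
    near_bottom = h_max > patch_size - margin
    near_right = w_max > patch_size - margin

    if near_top:
        return 1 if near_left else 3 if near_right else 2
    if near_right:
        return 5 if near_bottom else 4
    if near_bottom:
        return 7 if near_left else 6
    if near_left:
        return 8
    return 0


print(margin_status(((10, 10), (30, 30))))        # 1
print(margin_status(((500, 1000), (520, 1020))))  # 4
print(margin_status(((500, 500), (520, 520))))    # 0
```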

loki2.inference.postprocessing_cupy.get_edge_patch(position: List[int], row: int, col: int) List[List[int]]

Get the edge patches of a cell located at the border.

Parameters:
  • position – Position of the cell encoded as a list [top, right, down, left] where 1 indicates contact with that border.

  • row – Row position of the patch.

  • col – Column position of the patch.

Returns:

List of edge patches, each patch encoded as [row, col].

Return type:

List[List[int]]
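
One way to picture the result: every touched border points to a neighbouring patch, and a cell in a corner also reaches the diagonal neighbour. The sketch below implements that idea under two assumptions that the text above does not state: row indices increase downward (so the top neighbour is `row - 1`), and diagonal neighbours are included for corner cells. `neighbouring_patches` is a hypothetical name, and the order of the returned patches is arbitrary.

```python
from typing import List


def neighbouring_patches(position: List[int], row: int, col: int) -> List[List[int]]:
    """Sketch: list the [row, col] indices of patches adjacent to the touched borders."""
    top, right, down, left = position
    # Row/column offsets reachable from the touched borders (0 = stay in place).
    d_rows = [0] + ([-1] if top else []) + ([1] if down else [])
    d_cols = [0] + ([-1] if left else []) + ([1] if right else [])
    return [
        [row + dr, col + dc]
        for dr in d_rows
        for dc in d_cols
        if (dr, dc) != (0, 0)  # skip the patch the cell was detected in
    ]


print(neighbouring_patches([1, 1, 0, 0], row=5, col=5))  # [[5, 6], [4, 5], [4, 6]]
print(neighbouring_patches([0, 0, 0, 1], row=5, col=5))  # [[5, 4]]
print(neighbouring_patches([0, 0, 0, 0], row=5, col=5))  # []
```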