loki2.inference.postprocessing_cupy
===================================

.. py:module:: loki2.inference.postprocessing_cupy

.. autoapi-nested-parse::

   Postprocessing of Loki2 network output, tailored for the inference pipeline.

   This module provides GPU-accelerated postprocessing using CuPy for efficient
   cell detection and segmentation from network predictions.


Module Contents
---------------

.. py:class:: DetectionCellPostProcessorCupy(wsi: Union[loki2.data.dataclass.wsi.WSI, loki2.data.dataclass.wsi.WSIMetadata], nr_types: int, resolution: float = 0.25, classifier: torch.nn.Module = None, binary: bool = False, gt: bool = False)

   .. py:attribute:: wsi

   .. py:attribute:: nr_types

   .. py:attribute:: resolution

   .. py:attribute:: classifier

   .. py:attribute:: gt

   .. py:attribute:: binary

   .. py:method:: check_network_output(predictions_: Dict[str, torch.Tensor]) -> None

      Check if the network output is valid.

      :param predictions_: Network predictions dictionary with required keys:

         * nuclei_binary_map: Binary nucleus predictions. Shape: (B, H, W, 2)
         * nuclei_type_map: Type prediction of nuclei. Shape: (B, H, W, self.nr_types)
         * hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, H, W, 2)

      :raises AssertionError: If predictions dictionary is invalid or missing required keys.

   .. py:method:: post_process_batch(predictions_: dict) -> Tuple[torch.Tensor, List[dict]]

      Post-process a batch of predictions and generate a cell dictionary and
      instance predictions for each image in the batch.

      :param predictions_: Network predictions with tokens. Keys (required):

         * nuclei_binary_map: Binary nucleus predictions. Shape: (B, H, W, 2)
         * nuclei_type_map: Type prediction of nuclei. Shape: (B, H, W, self.nr_types)
         * hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, H, W, 2)
      :type predictions_: dict

      :returns:
         * torch.Tensor: Instance map. Each instance has its own integer. Shape: (B, H, W)
         * List of dictionaries. Each list entry is one image.
           Each dict contains another dict for each detected nucleus.
           For each nucleus, the following information is returned:
           "bbox", "centroid", "contour", "type_prob", "type"
      :rtype: Tuple[torch.Tensor, List[dict]]

   .. py:method:: post_process_single_image(pred_map: cupy.ndarray) -> Tuple[numpy.ndarray, dict[int, dict]]

      Process a single image and generate the cell dictionary and instance predictions.

      :param pred_map: Combined output of tp, np and hv branches, in the same order. Shape: (H, W, 4)
      :type pred_map: cp.ndarray

      :returns: Instance map of the image and the cell dictionary, with one entry per detected nucleus.
      :rtype: Tuple[np.ndarray, dict[int, dict]]

.. py:class:: BatchPoolingActor(detection_cell_postprocessor: DetectionCellPostProcessorCupy, run_conf: Dict[str, Any])

   Ray Actor for coordinating the postprocessing of batches.

   The postprocessing is done in a separate process to avoid blocking the main
   process. The calculation is done with the help of the
   `DetectionCellPostProcessorCupy` class. This actor coordinates the
   postprocessing of one batch and wraps the `DetectionCellPostProcessorCupy`
   class.

   .. py:attribute:: detection_cell_postprocessor

   .. py:attribute:: run_conf

   .. py:method:: convert_batch_to_graph_nodes(predictions: Dict[str, torch.Tensor], metadata: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[torch.Tensor], List[torch.Tensor]]

      Postprocess a batch of predictions and convert it to graph nodes.

      Returns the complete graph nodes (cell dictionary), the detection nodes
      (cell detection dictionary), the cell tokens and the cell positions.

      :param predictions: Network predictions dictionary with required keys:

         * nuclei_binary_map: Binary nucleus predictions. Shape: (B, H, W, 2)
         * nuclei_type_map: Type prediction of nuclei. Shape: (B, H, W, self.num_nuclei_classes)
         * hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, H, W, 2)
      :param metadata: List of metadata dictionaries for each patch.
         Each dictionary needs to contain:

         * row: Row index of the patch
         * col: Column index of the patch

         Other keys are optional.

      :returns:
         - Complete graph nodes (cell dictionaries)
         - Detection nodes (cell detection dictionaries)
         - Cell tokens
         - Cell positions
      :rtype: Tuple[List[Dict[str, Any]], List[Dict[str, Any]], List[torch.Tensor], List[torch.Tensor]]

   .. py:method:: convert_patch_to_graph_nodes(patch_cell_dict: dict, patch_metadata: dict, patch_tokens: torch.Tensor) -> Tuple[List[dict], List[dict], List[torch.Tensor], List[torch.Tensor]]

      Extract information from a single patch and convert it to graph nodes for a global view.

      :param patch_cell_dict: Dictionary containing the cell information.
         Each dictionary needs to contain the following keys:

         * bbox: Bounding box of the cell
         * centroid: Centroid of the cell
         * contour: Contour of the cell
         * type_prob: Probability of the cell type
         * type: Type of the cell
      :type patch_cell_dict: dict
      :param patch_metadata: Metadata dictionary for the patch.
         Each dictionary needs to contain the following keys:

         * row: Row index of the patch
         * col: Column index of the patch

         Other keys are optional but are stored in the graph nodes for later use.
      :type patch_metadata: dict
      :param patch_tokens: Tokens of the patch. Shape: (D, H, W)
      :type patch_tokens: torch.Tensor

      :returns:
         * List[dict]: Complete graph nodes (cell dictionary) of the patch
         * List[dict]: Detection nodes (cell detection dictionary) of the patch
         * List[torch.Tensor]: Cell tokens of the patch
         * List[torch.Tensor]: Cell positions (centroid) of the patch
      :rtype: Tuple[List[dict], List[dict], List[torch.Tensor], List[torch.Tensor]]

.. py:function:: get_cell_position(bbox: numpy.ndarray, patch_size: int = 1024) -> List[int]

   Get cell position as a list indicating which borders the cell touches.

   An entry is 1 if the cell touches the corresponding border: [top, right, down, left]

   :param bbox: Bounding box of cell as array of shape (2, 2) in (h, w) format.
   :param patch_size: Patch size. Defaults to 1024.

   :returns: List with 4 integers [top, right, down, left] indicating border contact (1) or not (0).
   :rtype: List[int]

.. py:function:: get_cell_position_marging(bbox: numpy.ndarray, patch_size: int = 1024, margin: int = 64) -> int

   Get the status of the cell, describing the cell position.

   A cell is either in the middle (0) or at one of the borders (1-8).
   Numbers are assigned clockwise, starting from top left:

   - top left = 1, top = 2, top right = 3, right = 4
   - bottom right = 5, bottom = 6, bottom left = 7, left = 8
   - Mid status is denoted by 0

   :param bbox: Bounding box of cell as array of shape (2, 2).
   :param patch_size: Patch size. Defaults to 1024.
   :param margin: Margin size. Defaults to 64.

   :returns: Cell status code (0-8).
   :rtype: int

.. py:function:: get_edge_patch(position: List[int], row: int, col: int) -> List[List[int]]

   Get the edge patches of a cell located at the border.

   :param position: Position of the cell encoded as a list [top, right, down, left]
      where 1 indicates contact with that border.
   :param row: Row position of the patch.
   :param col: Column position of the patch.

   :returns: List of edge patches, each patch encoded as [row, col].
   :rtype: List[List[int]]
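
The border encodings described for ``get_cell_position``, ``get_cell_position_marging`` and ``get_edge_patch`` can be illustrated with a minimal, self-contained sketch. It does not reproduce the module's implementation: the ``sketch_*`` names are hypothetical, and the code assumes that ``bbox`` is laid out as ``[[h_min, w_min], [h_max, w_max]]``, that "touching" a border means a coordinate equals ``0`` or ``patch_size``, and that patch row indices grow downward and column indices grow to the right.

```python
from typing import List

import numpy as np


def sketch_border_flags(bbox: np.ndarray, patch_size: int = 1024) -> List[int]:
    """Hypothetical stand-in: [top, right, down, left] flags, 1 = touches border.

    Assumes bbox = [[h_min, w_min], [h_max, w_max]].
    """
    (h_min, w_min), (h_max, w_max) = bbox
    top = int(h_min == 0)
    right = int(w_max == patch_size)
    down = int(h_max == patch_size)
    left = int(w_min == 0)
    return [top, right, down, left]


def sketch_margin_status(
    bbox: np.ndarray, patch_size: int = 1024, margin: int = 64
) -> int:
    """Hypothetical stand-in: status code 0-8, clockwise from top left."""
    (h_min, w_min), (h_max, w_max) = bbox
    near_top = h_min < margin
    near_left = w_min < margin
    near_bottom = h_max > patch_size - margin
    near_right = w_max > patch_size - margin

    # Corners are checked before plain edges so they take precedence.
    if near_top and near_left:
        return 1
    if near_top and near_right:
        return 3
    if near_bottom and near_right:
        return 5
    if near_bottom and near_left:
        return 7
    if near_top:
        return 2
    if near_right:
        return 4
    if near_bottom:
        return 6
    if near_left:
        return 8
    return 0  # bounding box lies fully inside the margin


def sketch_neighbor_patches(position: List[int], row: int, col: int) -> List[List[int]]:
    """Hypothetical stand-in: neighbouring patches a border cell may spill into.

    Assumes row - 1 is the patch above and col + 1 is the patch to the right.
    """
    top, right, down, left = position
    d_rows = [0] + ([-1] if top else []) + ([1] if down else [])
    d_cols = [0] + ([1] if right else []) + ([-1] if left else [])
    return [
        [row + dr, col + dc]
        for dr in d_rows
        for dc in d_cols
        if (dr, dc) != (0, 0)  # skip the patch itself
    ]


# A cell in the top-right corner of a 1024 x 1024 patch at grid position (2, 3).
corner_bbox = np.array([[0, 500], [20, 1024]])
corner_flags = sketch_border_flags(corner_bbox)
corner_status = sketch_margin_status(corner_bbox)
corner_neighbors = sketch_neighbor_patches(corner_flags, row=2, col=3)
```

Under these assumptions, a cell touching two adjacent borders yields three neighbouring patches (the two side neighbours and the diagonal one), while a cell touching a single border yields one.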