loki2.mil.models.src.utils
==========================

.. py:module:: loki2.mil.models.src.utils

.. autoapi-nested-parse::

   Utility functions for MIL model evaluation and visualization.

   This module provides functions for calculating metrics, evaluating models,
   generating attention heatmaps, plotting training curves, and computing
   cross-validation results.


Module Contents
---------------

.. py:function:: calculate_metrics(y_prob: torch.Tensor, labels: torch.Tensor, criterion: torch.nn.Module) -> Dict[str, Union[float, int, numpy.ndarray]]

   Calculate unified metrics for model evaluation.

   Computes loss, accuracy, AUROC, and predictions from model outputs and
   ground truth labels. Handles various tensor shapes automatically.

   :param y_prob: Model output probabilities of shape (batch_size, 1) or (batch_size,).
   :param labels: Ground truth labels of shape (batch_size, 1) or (batch_size,).
   :param criterion: Loss function (e.g., nn.BCELoss).
   :returns: Dictionary containing:

             - loss: Computed loss value (float).
             - accuracy: Accuracy score (float).
             - auroc: AUROC score (float, 0.5 if only one class is present).
             - correct: Number of correct predictions (int).
             - total: Total number of samples (int).
             - predictions: Binary predictions as NumPy array.
             - probabilities: Probability values as NumPy array.
   :rtype: Dict[str, Union[float, int, numpy.ndarray]]


.. py:function:: evaluate_model(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, device: torch.device, description: str = 'Evaluation') -> Dict[str, Union[float, int, List[str], List[int], List[float]]]

   Evaluate model on a dataset.

   Computes loss, accuracy, and AUROC, and collects predictions for all
   samples. Handles both logit and probability outputs automatically.

   :param model: PyTorch model to evaluate.
   :param data_loader: DataLoader for the evaluation dataset.
   :param device: PyTorch device (CPU or CUDA).
   :param description: Description string for logging. Defaults to "Evaluation".
   :returns: Dictionary containing:

             - patient_ids: List of patient ID strings.
             - true_labels: List of true labels (integers).
             - predicted_probs: List of predicted probabilities (floats).
             - predicted_labels: List of predicted labels (integers).
             - auroc: AUROC score (float).
             - accuracy: Accuracy score (float).
             - loss: Average loss (float).
             - correct: Number of correct predictions (int).
             - total: Total number of samples (int).
   :rtype: Dict[str, Union[float, int, List[str], List[int], List[float]]]
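   A minimal usage sketch for :py:func:`calculate_metrics` (with :py:func:`evaluate_model`
   sketched in the commented lines). The tensor values and the ``model``/``test_loader``
   names are illustrative assumptions, not fixtures provided by this module.

   .. code-block:: python

      import torch
      import torch.nn as nn

      from loki2.mil.models.src.utils import calculate_metrics

      # Toy bag-level probabilities and binary labels.
      y_prob = torch.tensor([0.9, 0.2, 0.7, 0.4])
      labels = torch.tensor([1.0, 0.0, 1.0, 1.0])

      metrics = calculate_metrics(y_prob, labels, nn.BCELoss())
      print(metrics["loss"], metrics["accuracy"], metrics["auroc"])

      # For a whole dataset, evaluate_model works on a DataLoader instead, e.g.:
      # results = evaluate_model(model, test_loader, torch.device("cpu"), description="Test")
      # where `model` and `test_loader` are a user-supplied MIL model and DataLoader.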
.. py:function:: generate_heatmaps(attention_data: pandas.DataFrame, pdf_filename: Union[str, pathlib.Path], title_prefix: str, quantile_20: float, quantile_80: float, point_size: float) -> None

   Generate attention heatmap PDF for all patients.

   Creates scatter plots showing attention weights overlaid on cell positions
   for each patient, saved as a multi-page PDF.

   :param attention_data: DataFrame with columns: Patient_ID, X, Y, log10_Attention.
   :param pdf_filename: Path to save the PDF file.
   :param title_prefix: Prefix for plot titles.
   :param quantile_20: 20th percentile for normalization (lower bound).
   :param quantile_80: 80th percentile for normalization (upper bound).
   :param point_size: Size of scatter plot points.


.. py:function:: generate_epoch_attention_analysis(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, train_patients: numpy.ndarray, model_folder: Union[str, pathlib.Path], epoch_num: int, point_size: float, embedding_dim: int, M_dim: int, L_dim: int, attention_branches: int, dropout_rate: float) -> None

   Generate attention analysis for a specific epoch.

   Extracts attention weights from the model, processes them, and generates
   heatmap visualizations for the train and test sets.

   :param model: Trained PyTorch MIL model.
   :param data_loader: DataLoader for the full dataset.
   :param train_patients: Array of training patient IDs.
   :param model_folder: Directory to save attention analysis files.
   :param epoch_num: Epoch number for file naming.
   :param point_size: Size of scatter plot points.
   :param embedding_dim: Embedding dimension (unused, kept for compatibility).
   :param M_dim: M dimension (unused, kept for compatibility).
   :param L_dim: L dimension (unused, kept for compatibility).
   :param attention_branches: Number of attention branches (unused, kept for compatibility).
   :param dropout_rate: Dropout rate (unused, kept for compatibility).


.. py:function:: downsample_data(dataframe: pandas.DataFrame, negative_ratio: float = 1.0, positive_negative_ratio: float = 1.0, max_patients: Optional[int] = None, embedding_ratio: float = 1.0, random_seed: int = 27) -> pandas.DataFrame

   Downsample data by patient and embedding counts.

   Reduces the dataset size by sampling patients and their embeddings
   according to the specified ratios, while maintaining class balance.

   :param dataframe: Input DataFrame with columns: Patient_ID, Patient_Label, and embedding features.
   :param negative_ratio: Ratio of negative patients to keep. Defaults to 1.0.
   :param positive_negative_ratio: Target ratio of positive to negative patients. Defaults to 1.0.
   :param max_patients: Maximum number of patients to keep. If None, no limit. Defaults to None.
   :param embedding_ratio: Ratio of embeddings to keep per patient. Defaults to 1.0.
   :param random_seed: Random seed for reproducibility. Defaults to 27.
   :returns: Downsampled DataFrame with the same structure as the input.
   :rtype: pd.DataFrame
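   The sketch below shows one plausible call to :py:func:`downsample_data` on a toy
   DataFrame. The column layout follows the documented Patient_ID / Patient_Label /
   embedding-feature convention; the feature names ``feat_0`` and ``feat_1`` are
   illustrative only.

   .. code-block:: python

      import pandas as pd

      from loki2.mil.models.src.utils import downsample_data

      # Toy bag-level data: several embeddings (rows) per patient.
      df = pd.DataFrame({
          "Patient_ID": ["P1"] * 4 + ["P2"] * 4 + ["P3"] * 4,
          "Patient_Label": [1] * 4 + [0] * 4 + [0] * 4,
          "feat_0": range(12),
          "feat_1": range(12),
      })

      # Keep half of the negative patients and half of each patient's embeddings,
      # capped at 100 patients, with a fixed seed for reproducibility.
      small_df = downsample_data(
          df,
          negative_ratio=0.5,
          embedding_ratio=0.5,
          max_patients=100,
          random_seed=27,
      )
      print(small_df["Patient_ID"].nunique(), len(small_df))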
.. py:function:: compute_and_plot_overall_metrics(all_test_predictions: List[Dict[str, Any]], all_fold_results: List[Dict[str, Any]], patient_ids: numpy.ndarray, args: Any, model_folder: Union[str, pathlib.Path], title_prefix: Optional[str] = None) -> Dict[str, Any]

   Compute overall 5-fold test metrics and plot the ROC curve.

   Aggregates predictions from all folds, computes overall metrics, generates
   the ROC curve plot, and saves comprehensive cross-validation results.

   :param all_test_predictions: List of test prediction dictionaries from each fold, each containing:
                                patient_ids, true_labels, predicted_probs, predicted_labels, fold.
   :param all_fold_results: List of fold result dictionaries, each containing:
                            fold, best_epoch, best_val_auroc, final_* metrics.
   :param patient_ids: Array of all patient IDs in the dataset.
   :param args: Arguments object containing n_folds, seed, etc.
   :param model_folder: Directory path to save outputs.
   :param title_prefix: Prefix for plot title (e.g., "cancer_type - signature"). Defaults to None.
   :returns: Dictionary containing:

             - cross_validation_settings: Dictionary with CV configuration.
             - fold_results: List of fold result dictionaries.
             - summary: Dictionary with mean and std of validation metrics.
             - 5fold_overall_test_metrics: Dictionary with overall test metrics including AUROC,
               accuracy, precision, recall, F1, sensitivity, specificity, confusion matrix, and
               file paths.
   :rtype: Dict[str, Any]


.. py:function:: plot_training_curves(training_history: List[Dict[str, Any]], fold_folder: Union[str, pathlib.Path], fold_num: int) -> None

   Plot training curves for a single fold.

   Generates three sets of plots: AUROC comparison, accuracy comparison, and
   loss comparison, each with individual and combined views.

   :param training_history: List of dictionaries, each containing epoch metrics: epoch, train_loss,
                            train_accuracy, train_auroc, val_loss, val_accuracy, val_auroc,
                            test_loss (optional), test_accuracy (optional), test_auroc (optional).
   :param fold_folder: Directory to save the plot files.
   :param fold_num: Fold number for plot titles.
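   A minimal usage sketch for :py:func:`plot_training_curves`. The two-epoch history
   values and the ``runs/fold_1`` output directory are made up for illustration.

   .. code-block:: python

      from pathlib import Path

      from loki2.mil.models.src.utils import plot_training_curves

      # Two epochs of made-up metrics in the documented key layout.
      training_history = [
          {"epoch": 1, "train_loss": 0.68, "train_accuracy": 0.55, "train_auroc": 0.60,
           "val_loss": 0.66, "val_accuracy": 0.58, "val_auroc": 0.62},
          {"epoch": 2, "train_loss": 0.52, "train_accuracy": 0.71, "train_auroc": 0.78,
           "val_loss": 0.57, "val_accuracy": 0.66, "val_auroc": 0.73},
      ]

      fold_folder = Path("runs/fold_1")
      fold_folder.mkdir(parents=True, exist_ok=True)
      plot_training_curves(training_history, fold_folder, fold_num=1)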