loki2.mil.models.src.utils
Utility functions for MIL model evaluation and visualization.
This module provides functions for calculating metrics, evaluating models, generating attention heatmaps, plotting training curves, and computing cross-validation results.
Module Contents
- loki2.mil.models.src.utils.calculate_metrics(y_prob: torch.Tensor, labels: torch.Tensor, criterion: torch.nn.Module) → Dict[str, float | int | numpy.ndarray]
Calculate unified metrics for model evaluation.
Computes loss, accuracy, AUROC, and predictions from model outputs and ground-truth labels. Accepts tensors of shape (batch_size, 1) or (batch_size,) without manual reshaping.
- Parameters:
y_prob – Model output probabilities of shape (batch_size, 1) or (batch_size,).
labels – Ground truth labels of shape (batch_size, 1) or (batch_size,).
criterion – Loss function (e.g., nn.BCELoss).
- Returns:
Dict containing:
loss: Computed loss value (float).
accuracy: Accuracy score (float).
auroc: AUROC score (float, 0.5 if only one class is present).
correct: Number of correct predictions (int).
total: Total number of samples (int).
predictions: Binary predictions as a NumPy array.
probabilities: Probability values as a NumPy array.
- Return type:
Dict[str, float | int | numpy.ndarray]
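The computation described above can be sketched in plain NumPy. This is a minimal sketch, not the module's implementation: `calculate_metrics_sketch` and `auroc` are hypothetical names, the 0.5 decision threshold (with `>=`) is an assumption, and binary cross-entropy is computed by hand as a stand-in for the `criterion` argument. Torch tensors would first be converted with `.detach().cpu().numpy()`; in practice `sklearn.metrics.roc_auc_score` gives the same AUROC as the rank formula used here.

```python
import numpy as np


def auroc(labels: np.ndarray, probs: np.ndarray) -> float:
    """Rank-based AUROC (Mann-Whitney U); returns 0.5 when only one class is present."""
    labels = labels.astype(int)
    n_pos = int(labels.sum())
    n_neg = labels.size - n_pos
    if n_pos == 0 or n_neg == 0:
        return 0.5
    # Assign 1-based ranks, averaging over tied scores.
    order = np.argsort(probs, kind="mergesort")
    sorted_probs = probs[order]
    ranks = np.empty(probs.size, dtype=float)
    i = 0
    while i < probs.size:
        j = i
        while j + 1 < probs.size and sorted_probs[j + 1] == sorted_probs[i]:
            j += 1
        ranks[order[i:j + 1]] = (i + j) / 2.0 + 1.0
        i = j + 1
    rank_sum_pos = ranks[labels == 1].sum()
    return float((rank_sum_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg))


def calculate_metrics_sketch(y_prob, labels, threshold: float = 0.5) -> dict:
    # Flatten (batch_size, 1) or (batch_size,) inputs to 1-D arrays.
    probs = np.asarray(y_prob, dtype=float).reshape(-1)
    y = np.asarray(labels, dtype=float).reshape(-1)
    # Hand-written binary cross-entropy standing in for `criterion` (assumption).
    eps = 1e-7
    p = np.clip(probs, eps, 1 - eps)
    loss = float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))
    # Assumed decision rule: probability >= threshold means class 1.
    preds = (probs >= threshold).astype(int)
    correct = int((preds == y.astype(int)).sum())
    total = int(y.size)
    return {
        "loss": loss,
        "accuracy": correct / total,
        "auroc": auroc(y, probs),
        "correct": correct,
        "total": total,
        "predictions": preds,
        "probabilities": probs,
    }
```

For example, `calculate_metrics_sketch([[0.9], [0.2], [0.6], [0.4]], [[1], [0], [0], [1]])` flattens both inputs, gets two of four predictions right, and ranks three of the four positive/negative pairs correctly (AUROC 0.75).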
- loki2.mil.models.src.utils.evaluate_model(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, device: torch.device, description: str = 'Evaluation') → Dict[str, float | int | List[str] | List[int] | List[float]]
Evaluate model on a dataset.
Computes metrics including loss, accuracy, AUROC, and collects predictions for all samples. Handles both logits and probability outputs automatically.
- Parameters:
model – PyTorch model to evaluate.
data_loader – DataLoader for the evaluation dataset.
device – PyTorch device (CPU or CUDA).
description – Description string for logging. Defaults to “Evaluation”.
- Returns:
Dict containing:
patient_ids: List of patient ID strings.
true_labels: List of true labels (integers).
predicted_probs: List of predicted probabilities (floats).
predicted_labels: List of predicted labels (integers).
auroc: AUROC score (float).
accuracy: Accuracy score (float).
loss: Average loss (float).
correct: Number of correct predictions (int).
total: Total number of samples (int).
- Return type:
Dict[str, float | int | List[str] | List[int] | List[float]]
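The "logits or probabilities" handling and the per-sample collection can be illustrated as follows. This is a sketch under stated assumptions: `to_probabilities` and `collect_predictions` are hypothetical helpers, the range check (any value outside [0, 1] means logits) is an assumed heuristic rather than the module's documented rule, and batches are modeled as plain `(patient_ids, labels, outputs)` tuples instead of a real `DataLoader`. Loss and AUROC would then be computed on the pooled lists.

```python
import numpy as np


def to_probabilities(outputs) -> np.ndarray:
    """Assumed heuristic: values outside [0, 1] are treated as logits and passed through a sigmoid."""
    outputs = np.asarray(outputs, dtype=float).reshape(-1)
    if outputs.min() < 0.0 or outputs.max() > 1.0:
        return 1.0 / (1.0 + np.exp(-outputs))
    return outputs


def collect_predictions(batches, threshold: float = 0.5) -> dict:
    """Accumulate per-patient results over (patient_ids, labels, outputs) batches."""
    result = {"patient_ids": [], "true_labels": [], "predicted_probs": [], "predicted_labels": []}
    for patient_ids, labels, outputs in batches:
        probs = to_probabilities(outputs)
        result["patient_ids"].extend(str(pid) for pid in patient_ids)
        result["true_labels"].extend(int(lab) for lab in np.asarray(labels).reshape(-1))
        result["predicted_probs"].extend(float(p) for p in probs)
        result["predicted_labels"].extend(int(p >= threshold) for p in probs)
    correct = sum(int(t == p) for t, p in zip(result["true_labels"], result["predicted_labels"]))
    result["correct"] = correct
    result["total"] = len(result["true_labels"])
    result["accuracy"] = correct / result["total"] if result["total"] else 0.0
    return result
```

A per-batch range check is fragile: a batch of logits that all happen to fall inside [0, 1] would be misread as probabilities. Knowing whether the model ends in a sigmoid layer is the more reliable signal.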
- loki2.mil.models.src.utils.generate_heatmaps(attention_data: pandas.DataFrame, pdf_filename: str | pathlib.Path, title_prefix: str, quantile_20: float, quantile_80: float, point_size: float) → None
Generate attention heatmap PDF for all patients.
Creates scatter plots showing attention weights overlaid on cell positions for each patient, saved as a multi-page PDF.
- Parameters:
attention_data – DataFrame with columns: Patient_ID, X, Y, log10_Attention.
pdf_filename – Path to save the PDF file.
title_prefix – Prefix for plot titles.
quantile_20 – 20th percentile for normalization (lower bound).
quantile_80 – 80th percentile for normalization (upper bound).
point_size – Size of scatter plot points.
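One way the two quantile bounds can drive the color scale is to clip `log10_Attention` to [quantile_20, quantile_80] and rescale it to [0, 1]. The sketch below assumes that scheme; `normalize_attention` is a hypothetical helper, and the quantiles are computed with `np.quantile` on toy values. In matplotlib the same mapping is available as `matplotlib.colors.Normalize(vmin=quantile_20, vmax=quantile_80, clip=True)`.

```python
import numpy as np


def normalize_attention(log10_attention, quantile_20: float, quantile_80: float) -> np.ndarray:
    """Clip log10 attention to [quantile_20, quantile_80] and rescale it to [0, 1]."""
    values = np.asarray(log10_attention, dtype=float)
    span = quantile_80 - quantile_20
    if span <= 0:
        # Degenerate bounds: map every cell to the middle of the color range.
        return np.full_like(values, 0.5)
    return np.clip((values - quantile_20) / span, 0.0, 1.0)


# Toy attention weights for four cells of one patient.
attention = np.array([1e-4, 1e-3, 1e-2, 1e-1])
log_att = np.log10(attention)
q20, q80 = np.quantile(log_att, [0.2, 0.8])
colors = normalize_attention(log_att, q20, q80)
```

Clipping at the 20th and 80th percentiles keeps a few extreme cells from compressing the color range for everything else.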
- loki2.mil.models.src.utils.generate_epoch_attention_analysis(model: torch.nn.Module, data_loader: torch.utils.data.DataLoader, train_patients: numpy.ndarray, model_folder: str | pathlib.Path, epoch_num: int, point_size: float, embedding_dim: int, M_dim: int, L_dim: int, attention_branches: int, dropout_rate: float) → None
Generate attention analysis for a specific epoch.
Extracts attention weights from the model, processes them, and generates heatmap visualizations for the train and test sets.
- Parameters:
model – Trained PyTorch MIL model.
data_loader – DataLoader for the full dataset.
train_patients – Array of training patient IDs.
model_folder – Directory to save attention analysis files.
epoch_num – Epoch number for file naming.
point_size – Size of scatter plot points.
embedding_dim – Embedding dimension (unused, kept for compatibility).
M_dim – M dimension (unused, kept for compatibility).
L_dim – L dimension (unused, kept for compatibility).
attention_branches – Number of attention branches (unused, kept for compatibility).
dropout_rate – Dropout rate (unused, kept for compatibility).
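The processing step between extracting weights and plotting can be sketched as a log10 transform followed by a split on `train_patients`. Assumptions: `split_attention_by_set` is a hypothetical helper, any patient not listed in `train_patients` is treated as test, and the 1e-12 floor that guards against `log10(0)` is an arbitrary choice. The X and Y cell coordinates would be carried along in the same way.

```python
import numpy as np


def split_attention_by_set(cell_patient_ids, attention_weights, train_patients):
    """Log-transform per-cell attention and split cells into train/test by patient."""
    patient_ids = np.asarray(cell_patient_ids)
    weights = np.asarray(attention_weights, dtype=float)
    # Floor the weights so log10 never sees zero (assumed value).
    log10_attention = np.log10(np.maximum(weights, 1e-12))
    # Assumption: patients outside train_patients form the test set.
    is_train = np.isin(patient_ids, np.asarray(train_patients))
    train = {"Patient_ID": patient_ids[is_train], "log10_Attention": log10_attention[is_train]}
    test = {"Patient_ID": patient_ids[~is_train], "log10_Attention": log10_attention[~is_train]}
    return train, test


train, test = split_attention_by_set(
    ["A", "A", "B", "C"], [0.5, 0.5, 1.0, 1.0], np.array(["A", "C"])
)
```

Each returned mapping has the `Patient_ID` and `log10_Attention` fields that `generate_heatmaps` expects, so the two sets can be rendered as separate PDFs.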
- loki2.mil.models.src.utils.downsample_data(dataframe: pandas.DataFrame, negative_ratio: float = 1.0, positive_negative_ratio: float = 1.0, max_patients: int | None = None, embedding_ratio: float = 1.0, random_seed: int = 27) → pandas.DataFrame
Downsample data by patient and embedding counts.
Reduces the dataset size by sampling patients and their embeddings according to specified ratios, while maintaining class balance.
- Parameters:
dataframe – Input DataFrame with columns: Patient_ID, Patient_Label, and embedding features.
negative_ratio – Ratio of negative patients to keep. Defaults to 1.0.
positive_negative_ratio – Target ratio of positive to negative patients. Defaults to 1.0.
max_patients – Maximum number of patients to keep. If None, no limit. Defaults to None.
embedding_ratio – Ratio of embeddings to keep per patient. Defaults to 1.0.
random_seed – Random seed for reproducibility. Defaults to 27.
- Returns:
Downsampled DataFrame with the same structure as input.
- Return type:
pd.DataFrame
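The two stages described above (patient selection, then per-patient embedding sampling) can be sketched without pandas. This is a hypothetical illustration: `select_patients` and `sample_rows_per_patient` are not part of the module, and the order of operations, the rounding, the proportional cap for `max_patients`, and the "keep at least one embedding" rule are all assumptions. With a real DataFrame, the second stage maps naturally onto pandas `DataFrameGroupBy.sample(frac=..., random_state=...)`.

```python
import numpy as np


def select_patients(patient_labels: dict, negative_ratio=1.0, positive_negative_ratio=1.0,
                    max_patients=None, random_seed=27) -> list:
    """Hypothetical patient-level selection: sample negatives, match positives, then cap."""
    rng = np.random.default_rng(random_seed)
    negatives = sorted(p for p, y in patient_labels.items() if y == 0)
    positives = sorted(p for p, y in patient_labels.items() if y == 1)
    # 1) Keep a fraction of the negative patients.
    n_neg = int(round(len(negatives) * negative_ratio))
    kept_neg = list(rng.choice(negatives, size=n_neg, replace=False)) if n_neg else []
    # 2) Sample positives to approach the requested positive:negative ratio.
    n_pos = min(len(positives), int(round(n_neg * positive_negative_ratio)))
    kept_pos = list(rng.choice(positives, size=n_pos, replace=False)) if n_pos else []
    # 3) Optionally cap the total, scaling both classes by the same factor.
    if max_patients is not None and n_neg + n_pos > max_patients:
        scale = max_patients / (n_neg + n_pos)
        kept_neg = kept_neg[: int(n_neg * scale)]
        kept_pos = kept_pos[: int(n_pos * scale)]
    return sorted(str(p) for p in kept_neg + kept_pos)


def sample_rows_per_patient(row_patient_ids, embedding_ratio=1.0, random_seed=27) -> list:
    """Return sorted row indices keeping ~embedding_ratio of each patient's rows (at least one)."""
    rng = np.random.default_rng(random_seed)
    row_patient_ids = np.asarray(row_patient_ids)
    keep = []
    for pid in np.unique(row_patient_ids):
        idx = np.flatnonzero(row_patient_ids == pid)
        n_keep = max(1, int(round(len(idx) * embedding_ratio)))
        keep.extend(rng.choice(idx, size=n_keep, replace=False).tolist())
    return sorted(keep)
```

Seeding a single `np.random.default_rng(random_seed)` per call makes repeated runs with the default seed of 27 select the same patients and rows.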
- loki2.mil.models.src.utils.compute_and_plot_overall_metrics(all_test_predictions: List[Dict[str, Any]], all_fold_results: List[Dict[str, Any]], patient_ids: numpy.ndarray, args: Any, model_folder: str | pathlib.Path, title_prefix: str | None = None) → Dict[str, Any]
Compute overall 5-fold test metrics and plot the ROC curve.
Aggregates predictions from all folds, computes overall metrics, generates an ROC curve plot, and saves comprehensive cross-validation results.
- Parameters:
all_test_predictions – List of test prediction dictionaries from each fold, each containing: patient_ids, true_labels, predicted_probs, predicted_labels, fold.
all_fold_results – List of fold result dictionaries, each containing: fold, best_epoch, best_val_auroc, final_* metrics.
patient_ids – Array of all patient IDs in the dataset.
args – Arguments object containing n_folds, seed, etc.
model_folder – Directory path to save outputs.
title_prefix – Prefix for plot title (e.g., “cancer_type - signature”). Defaults to None.
- Returns:
Dict containing:
cross_validation_settings: Dictionary with the CV configuration.
fold_results: List of fold result dictionaries.
summary: Dictionary with the mean and standard deviation of validation metrics.
5fold_overall_test_metrics: Dictionary with overall test metrics, including AUROC, accuracy, precision, recall, F1, sensitivity, specificity, the confusion matrix, and file paths.
- Return type:
Dict[str, Any]
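Pooling the out-of-fold predictions and deriving the confusion-matrix metrics listed above might look like this. It is a sketch, not the module's code: `pooled_binary_metrics` and `summarize_folds` are hypothetical, only the `true_labels` and `predicted_labels` keys of each fold dictionary are used, and `np.std` here is the population standard deviation (`ddof=0`), which may differ from what the module reports. The pooled AUROC would be computed from `predicted_probs` in the same way.

```python
import numpy as np


def pooled_binary_metrics(all_test_predictions: list) -> dict:
    """Pool out-of-fold predictions and derive confusion-matrix metrics."""
    y_true = np.concatenate([np.asarray(f["true_labels"], dtype=int) for f in all_test_predictions])
    y_pred = np.concatenate([np.asarray(f["predicted_labels"], dtype=int) for f in all_test_predictions])
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))

    def safe_div(a, b):
        return a / b if b else 0.0

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)  # recall and sensitivity are the same quantity
    specificity = safe_div(tn, tn + fp)
    return {
        "accuracy": safe_div(tp + tn, y_true.size),
        "precision": precision,
        "recall": recall,
        "sensitivity": recall,
        "specificity": specificity,
        "f1": safe_div(2 * precision * recall, precision + recall),
        # Rows are true class 0/1, columns are predicted class 0/1.
        "confusion_matrix": [[tn, fp], [fn, tp]],
    }


def summarize_folds(all_fold_results: list, key: str = "best_val_auroc") -> dict:
    """Mean and population standard deviation of one per-fold metric."""
    values = np.array([r[key] for r in all_fold_results], dtype=float)
    return {"mean": float(values.mean()), "std": float(values.std())}
```

The `[[tn, fp], [fn, tp]]` layout matches the convention of `sklearn.metrics.confusion_matrix` for labels 0 and 1.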
- loki2.mil.models.src.utils.plot_training_curves(training_history: List[Dict[str, Any]], fold_folder: str | pathlib.Path, fold_num: int) → None
Plot training curves for a single fold.
Generates three sets of plots: AUROC comparison, accuracy comparison, and loss comparison, each with individual and combined views.
- Parameters:
training_history – List of dictionaries, each containing epoch metrics: epoch, train_loss, train_accuracy, train_auroc, val_loss, val_accuracy, val_auroc, test_loss (optional), test_accuracy (optional), test_auroc (optional).
fold_folder – Directory to save the plot files.
fold_num – Fold number for plot titles.
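Before any plotting, `training_history` has to be turned into one series per split and metric, with the optional `test_*` entries handled. The sketch below covers only that data-preparation step. `extract_curves` is a hypothetical helper, and the rule that a `test_*` series is kept only when every epoch provides a non-None value is an assumption. Each resulting list could then be passed to matplotlib, for example `plt.plot(curves["epoch"], curves["train_auroc"])`.

```python
def extract_curves(training_history: list) -> dict:
    """Collect per-split metric series; a series is kept only if every epoch has a value."""
    curves = {"epoch": [h["epoch"] for h in training_history]}
    for metric in ("auroc", "accuracy", "loss"):
        for split in ("train", "val", "test"):
            key = f"{split}_{metric}"
            # test_* metrics are optional, so skip any series with missing values.
            if all(h.get(key) is not None for h in training_history):
                curves[key] = [float(h[key]) for h in training_history]
    return curves


history = [
    {"epoch": 1, "train_loss": 0.70, "train_accuracy": 0.50, "train_auroc": 0.60,
     "val_loss": 0.69, "val_accuracy": 0.50, "val_auroc": 0.55},
    {"epoch": 2, "train_loss": 0.62, "train_accuracy": 0.65, "train_auroc": 0.72,
     "val_loss": 0.65, "val_accuracy": 0.60, "val_auroc": 0.63},
]
curves = extract_curves(history)
```

Grouping the series by metric (AUROC, accuracy, loss) mirrors the three plot sets described above, with one line per split in each group.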