loki2.mil.models.src.train
Training module for Multiple Instance Learning (MIL) model.
This module provides functions for training MIL models with cross-validation, early stopping, and attention analysis.
Module Contents
- loki2.mil.models.src.train.train_single_fold(fold: int, train_patients: numpy.ndarray, val_patients: numpy.ndarray, data_final: pandas.DataFrame, args: Any, model_folder: str | pathlib.Path, device: torch.device, enable_fresh_initialization: bool = True, return_test_predictions: bool = True) → Dict[str, Any] | Tuple[Dict[str, Any], Dict[str, Any]]
Train a single fold of cross-validation.
Performs training with early stopping, model checkpointing, and attention analysis. Supports weighted sampling for class imbalance and various learning rate scheduling strategies.
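The early stopping and class-imbalance handling described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the loki2 implementation. The names inverse_frequency_weights and EarlyStopping, the patience and min_delta parameters, and the choice of validation AUROC as the monitored metric (suggested by the best_val_auroc result field) are all hypothetical. In a PyTorch pipeline, per-patient weights like these are commonly passed to torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True).

```python
from typing import Dict, List, Sequence


def inverse_frequency_weights(labels: Sequence[int]) -> List[float]:
    """Weight each patient (bag) by 1 / (number of patients sharing its label)."""
    counts: Dict[int, int] = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return [1.0 / counts[label] for label in labels]


class EarlyStopping:
    """Hypothetical patience-based stopper that monitors validation AUROC."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_score = float("-inf")
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def step(self, epoch: int, val_auroc: float) -> bool:
        """Record one epoch's score and return True when training should stop."""
        if val_auroc > self.best_score + self.min_delta:
            self.best_score = val_auroc
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
            # A real training loop would save a checkpoint to model_folder here.
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


if __name__ == "__main__":
    print(inverse_frequency_weights([0, 0, 0, 1]))  # minority patient gets 3x the weight
    stopper = EarlyStopping(patience=2)
    for epoch, auroc in enumerate([0.60, 0.72, 0.71, 0.70, 0.75]):
        if stopper.step(epoch, auroc):
            break
    print(stopper.best_epoch, stopper.best_score)  # best epoch 1, AUROC 0.72
```

With patience=2 the loop stops at epoch 3, so the later 0.75 is never seen; this is the usual trade-off between training time and the risk of stopping too early.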
- Parameters:
fold – Fold number, from 0 to args.n_folds - 1.
train_patients – Array of training set patient IDs.
val_patients – Array of validation/test set patient IDs.
data_final – Full DataFrame with the columns Patient_ID, Patient_Label, and cell_position, plus the embedding feature columns.
args – Arguments object containing hyperparameters and settings.
model_folder – Directory path to save model checkpoints and results.
device – PyTorch device (CPU or CUDA).
enable_fresh_initialization – Whether to enable fresh model weight initialization for each fold. Defaults to True.
return_test_predictions – Whether to also return the test set predictions, used for the overall evaluation across all 5 folds. Defaults to True.
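Because train_patients and val_patients hold patient IDs rather than cell rows, the folds are presumably split at the patient level so that no patient's cells appear on both sides of a split. The stdlib-only sketch below builds stratified, patient-level folds from per-cell rows. The function name patient_level_folds, its seed parameter, and the round-robin stratification are assumptions, not the module's documented behavior. The returned lists would need numpy.asarray(...) before being passed as train_patients and val_patients.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def patient_level_folds(
    patient_ids: Sequence[str],
    labels: Sequence[int],
    n_folds: int = 5,
    seed: int = 0,
) -> List[Tuple[List[str], List[str]]]:
    """Split patients (not cells) into n_folds stratified (train, val) pairs.

    patient_ids and labels have one entry per cell row, as in data_final;
    each patient is assumed to carry a single Patient_Label.
    """
    # Collapse cell rows to one label per patient.
    patient_label: Dict[str, int] = {}
    for pid, label in zip(patient_ids, labels):
        patient_label.setdefault(pid, label)

    # Group patients by class, shuffle within each class, deal round-robin into folds.
    by_class: Dict[int, List[str]] = defaultdict(list)
    for pid, label in patient_label.items():
        by_class[label].append(pid)
    rng = random.Random(seed)
    fold_members: List[List[str]] = [[] for _ in range(n_folds)]
    for label in sorted(by_class):
        members = sorted(by_class[label])
        rng.shuffle(members)
        for i, pid in enumerate(members):
            fold_members[i % n_folds].append(pid)

    # Fold k validates on its own patients and trains on everyone else.
    splits = []
    for fold in range(n_folds):
        val = sorted(fold_members[fold])
        train = sorted(p for f, m in enumerate(fold_members) if f != fold for p in m)
        splits.append((train, val))
    return splits
```

Each fold index from 0 to n_folds - 1 then maps directly onto the fold argument of train_single_fold.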
- Returns:
- If return_test_predictions is True:
Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:
- fold_results: Dictionary with fold training results including:
fold: Fold number.
best_epoch: Best epoch number.
best_val_auroc: Best validation AUROC.
final_train_accuracy: Final training accuracy.
final_train_auroc: Final training AUROC.
final_val_accuracy: Final validation accuracy.
final_val_auroc: Final validation AUROC.
final_test_accuracy: Final test accuracy.
final_test_auroc: Final test AUROC.
- test_predictions: Dictionary with test set predictions:
patient_ids: List of patient IDs.
true_labels: List of true labels.
predicted_probs: List of predicted probabilities.
predicted_labels: List of predicted labels.
- If return_test_predictions is False:
Dict[str, Any]: Only the fold_results dictionary.
- Return type:
Tuple[Dict[str, Any], Dict[str, Any]] if return_test_predictions is True, otherwise Dict[str, Any]
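A caller that runs every fold has to handle both return shapes. The sketch below, assuming binary 0/1 labels and predicted_probs holding positive-class probabilities, averages the per-fold final_test_auroc values and also computes one AUROC over the test predictions pooled across folds. The helpers summarize_folds and roc_auc are hypothetical. roc_auc uses the pairwise (Mann-Whitney) definition, which is quadratic in the number of patients but adequate at cohort scale; sklearn.metrics.roc_auc_score would normally be used instead.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union


def roc_auc(true_labels: Sequence[int], scores: Sequence[float]) -> float:
    """Probability that a random positive outranks a random negative (ties count 0.5)."""
    pos = [s for y, s in zip(true_labels, scores) if y == 1]
    neg = [s for y, s in zip(true_labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUROC needs at least one positive and one negative")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


FoldOutput = Union[Dict[str, Any], Tuple[Dict[str, Any], Dict[str, Any]]]


def summarize_folds(outputs: List[FoldOutput]) -> Dict[str, float]:
    """Mean per-fold test AUROC, plus AUROC over pooled test predictions if available."""
    fold_aurocs: List[float] = []
    pooled_labels: List[int] = []
    pooled_probs: List[float] = []
    for out in outputs:
        if isinstance(out, tuple):  # return_test_predictions=True
            fold_results, test_predictions = out
            pooled_labels.extend(test_predictions["true_labels"])
            pooled_probs.extend(test_predictions["predicted_probs"])
        else:  # return_test_predictions=False
            fold_results = out
        fold_aurocs.append(fold_results["final_test_auroc"])
    summary = {"mean_fold_test_auroc": sum(fold_aurocs) / len(fold_aurocs)}
    if pooled_labels:
        summary["pooled_test_auroc"] = roc_auc(pooled_labels, pooled_probs)
    return summary
```

The two summaries can disagree: averaging per-fold AUROCs ignores calibration differences between folds, whereas pooling ranks every patient's probability against every other patient's.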