loki2.mil.models.src.train

Training module for Multiple Instance Learning (MIL) models.

This module provides functions for training MIL models with cross-validation, early stopping, and attention analysis.

Module Contents

loki2.mil.models.src.train.train_single_fold(fold: int, train_patients: numpy.ndarray, val_patients: numpy.ndarray, data_final: pandas.DataFrame, args: Any, model_folder: str | pathlib.Path, device: torch.device, enable_fresh_initialization: bool = True, return_test_predictions: bool = True) → Dict[str, Any] | Tuple[Dict[str, Any], Dict[str, Any]]

Train a single fold of cross-validation.

Performs training with early stopping, model checkpointing, and attention analysis. Supports weighted sampling for class imbalance and various learning rate scheduling strategies.
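The early-stopping behaviour mentioned above can be illustrated with a minimal sketch. It assumes that validation AUROC is the monitored metric (suggested by the best_val_auroc result field, but not confirmed here) and that training stops after a fixed number of epochs without improvement. The EarlyStopper class, its patience and min_delta fields, and the AUROC values are hypothetical and are not part of this module.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class EarlyStopper:
    """Track the best validation AUROC and signal when to stop (hypothetical helper)."""

    patience: int = 5        # epochs to wait without improvement (assumed value)
    min_delta: float = 0.0   # minimum gain that counts as an improvement
    best_val_auroc: float = float("-inf")
    best_epoch: Optional[int] = None
    _stale_epochs: int = 0

    def step(self, epoch: int, val_auroc: float) -> bool:
        """Record one epoch; return True when training should stop."""
        if val_auroc > self.best_val_auroc + self.min_delta:
            # New best: a real training loop would save a checkpoint here.
            self.best_val_auroc = val_auroc
            self.best_epoch = epoch
            self._stale_epochs = 0
        else:
            self._stale_epochs += 1
        return self._stale_epochs >= self.patience


# Invented validation AUROC values, one per epoch.
history = [0.61, 0.70, 0.74, 0.73, 0.74, 0.72, 0.71]
stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, auroc in enumerate(history):
    if stopper.step(epoch, auroc):
        stopped_at = epoch
        break
```

In this sketch a tie with the current best does not reset the patience counter, so the earliest epoch that reached the best score is the one reported as best_epoch.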

Parameters:
  • fold – Fold number, from 0 to args.n_folds - 1.

  • train_patients – Array of training set patient IDs.

  • val_patients – Array of validation/test set patient IDs.

  • data_final – Full data DataFrame with columns: Patient_ID, Patient_Label, cell_position, and embedding features.

  • args – Arguments object containing hyperparameters and settings.

  • model_folder – Directory path to save model checkpoints and results.

  • device – PyTorch device (CPU or CUDA).

  • enable_fresh_initialization – Whether to enable fresh model weight initialization for each fold. Defaults to True.

  • return_test_predictions – Whether to return test predictions for 5-fold overall evaluation. Defaults to True.
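Because train_patients and val_patients are patient IDs rather than row indices, the split is presumably made at the patient level, so that all cells of one patient stay on the same side. The following sketch shows one way such folds could be built. The patient_level_folds helper, the stratification scheme, and the toy cohort are assumptions for illustration; the module's own splitting logic is not documented in this section.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple


def patient_level_folds(
    patient_labels: Dict[str, int], n_folds: int, seed: int = 0
) -> List[Tuple[List[str], List[str]]]:
    """Split patients (not cells) into label-stratified folds; hypothetical helper."""
    rng = random.Random(seed)
    by_label: Dict[int, List[str]] = defaultdict(list)
    for patient_id, label in sorted(patient_labels.items()):
        by_label[label].append(patient_id)

    # Deal each class round-robin so every fold gets a similar label mix.
    fold_members: List[List[str]] = [[] for _ in range(n_folds)]
    for label in sorted(by_label):
        patients = by_label[label]
        rng.shuffle(patients)
        for i, patient_id in enumerate(patients):
            fold_members[i % n_folds].append(patient_id)

    folds = []
    for fold in range(n_folds):
        val_patients = sorted(fold_members[fold])
        train_patients = sorted(
            p for k, members in enumerate(fold_members) if k != fold for p in members
        )
        folds.append((train_patients, val_patients))
    return folds


# Toy cohort: 10 patients with alternating labels (invented data).
patient_labels = {f"P{i:02d}": i % 2 for i in range(10)}
folds = patient_level_folds(patient_labels, n_folds=5)
```

Since the signature expects numpy.ndarray arguments, the lists produced here would be converted with numpy.asarray before being passed as train_patients and val_patients.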

Returns:

If return_test_predictions is True:

Tuple[Dict[str, Any], Dict[str, Any]]: A tuple containing:
  • fold_results: Dictionary with fold training results including:
    • fold: Fold number.

    • best_epoch: Best epoch number.

    • best_val_auroc: Best validation AUROC.

    • final_train_accuracy: Final training accuracy.

    • final_train_auroc: Final training AUROC.

    • final_val_accuracy: Final validation accuracy.

    • final_val_auroc: Final validation AUROC.

    • final_test_accuracy: Final test accuracy.

    • final_test_auroc: Final test AUROC.

  • test_predictions: Dictionary with test set predictions:
    • patient_ids: List of patient IDs.

    • true_labels: List of true labels.

    • predicted_probs: List of predicted probabilities.

    • predicted_labels: List of predicted labels.

If return_test_predictions is False:

Dict[str, Any]: Only the fold_results dictionary.

Return type:

Dict[str, Any] | Tuple[Dict[str, Any], Dict[str, Any]]
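Since the return value is either a single dictionary or a two-element tuple, calling code has to handle both shapes. The sketch below shows one way to normalise the output and pool test predictions across folds. The helpers unpack_fold_output and pool_predictions and the mock fold outputs are hypothetical; only the dictionary keys come from the description above.

```python
from typing import Any, Dict, List, Optional, Tuple, Union

FoldOutput = Union[Dict[str, Any], Tuple[Dict[str, Any], Dict[str, Any]]]

PREDICTION_KEYS = ("patient_ids", "true_labels", "predicted_probs", "predicted_labels")


def unpack_fold_output(
    output: FoldOutput,
) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    """Return (fold_results, test_predictions), with None when predictions are absent."""
    if isinstance(output, tuple):
        fold_results, test_predictions = output
        return fold_results, test_predictions
    return output, None


def pool_predictions(outputs: List[FoldOutput]) -> Dict[str, List[Any]]:
    """Concatenate per-fold test predictions into one dictionary of lists."""
    pooled: Dict[str, List[Any]] = {key: [] for key in PREDICTION_KEYS}
    for output in outputs:
        _, predictions = unpack_fold_output(output)
        if predictions is None:
            continue  # this fold was run with return_test_predictions=False
        for key in PREDICTION_KEYS:
            pooled[key].extend(predictions[key])
    return pooled


# Mock outputs standing in for two calls to train_single_fold (invented values).
fold0 = (
    {"fold": 0, "best_epoch": 4, "best_val_auroc": 0.81},
    {
        "patient_ids": ["P01", "P02"],
        "true_labels": [0, 1],
        "predicted_probs": [0.2, 0.9],
        "predicted_labels": [0, 1],
    },
)
fold1 = {"fold": 1, "best_epoch": 7, "best_val_auroc": 0.78}  # results only

results0, preds0 = unpack_fold_output(fold0)
results1, preds1 = unpack_fold_output(fold1)
pooled = pool_predictions([fold0, fold1])
```

Pooling predictions this way allows a single overall metric to be computed over all held-out patients instead of averaging per-fold scores.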