loki2.mil.models.src.train
==========================

.. py:module:: loki2.mil.models.src.train

.. autoapi-nested-parse::

   Training module for Multiple Instance Learning (MIL) models.

   This module provides functions for training MIL models with cross-validation,
   early stopping, and attention analysis.


Module Contents
---------------

.. py:function:: train_single_fold(fold: int, train_patients: numpy.ndarray, val_patients: numpy.ndarray, data_final: pandas.DataFrame, args: Any, model_folder: Union[str, pathlib.Path], device: torch.device, enable_fresh_initialization: bool = True, return_test_predictions: bool = True) -> Union[Dict[str, Any], Tuple[Dict[str, Any], Dict[str, Any]]]

   Train a single fold of cross-validation.

   Performs training with early stopping, model checkpointing, and attention
   analysis. Supports weighted sampling to handle class imbalance, as well as
   several learning-rate scheduling strategies.

   :param fold: Fold index, from 0 to ``args.n_folds - 1``.
   :param train_patients: Array of training-set patient IDs.
   :param val_patients: Array of validation/test-set patient IDs.
   :param data_final: Full data DataFrame with the columns ``Patient_ID``,
       ``Patient_Label``, ``cell_position``, and the embedding features.
   :param args: Arguments object containing hyperparameters and settings.
   :param model_folder: Directory in which to save model checkpoints and results.
   :param device: PyTorch device (CPU or CUDA).
   :param enable_fresh_initialization: Whether to freshly initialize the model
       weights for each fold. Defaults to True.
   :param return_test_predictions: Whether to also return the test-set
       predictions, for the overall 5-fold evaluation. Defaults to True.
   :returns: If ``return_test_predictions`` is True, a tuple
       ``(fold_results, test_predictions)``; if it is False, only
       ``fold_results``.

       ``fold_results`` is a dictionary of fold training results:

       - ``fold``: Fold number.
       - ``best_epoch``: Best epoch number.
       - ``best_val_auroc``: Best validation AUROC.
       - ``final_train_accuracy``: Final training accuracy.
       - ``final_train_auroc``: Final training AUROC.
       - ``final_val_accuracy``: Final validation accuracy.
       - ``final_val_auroc``: Final validation AUROC.
       - ``final_test_accuracy``: Final test accuracy.
       - ``final_test_auroc``: Final test AUROC.

       ``test_predictions`` is a dictionary of test-set predictions:

       - ``patient_ids``: List of patient IDs.
       - ``true_labels``: List of true labels.
       - ``predicted_probs``: List of predicted probabilities.
       - ``predicted_labels``: List of predicted labels.
   :rtype: Union[Dict[str, Any], Tuple[Dict[str, Any], Dict[str, Any]]]
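
The following is a minimal, hypothetical sketch of a caller that runs ``train_single_fold`` over all folds and pools the per-fold test predictions. It is not taken from ``loki2``. The helper names ``patient_folds`` and ``collect_results`` are illustrative only. The random patient-level split is an assumption, and the actual module may stratify by ``Patient_Label`` or split differently. ``train_fn`` stands in for ``train_single_fold`` with ``data_final``, ``args``, ``model_folder``, and ``device`` already bound. The split uses unique patient IDs rather than cells, which keeps all of a patient's cells on one side of every split.

.. code-block:: python

   from typing import Any, Callable, Dict, List, Tuple

   import numpy as np


   def patient_folds(
       patient_ids: np.ndarray, n_folds: int, seed: int = 0
   ) -> List[Tuple[int, np.ndarray, np.ndarray]]:
       """Hypothetical helper: split unique patient IDs into disjoint folds."""
       unique = np.unique(patient_ids)  # one entry per patient, not per cell
       shuffled = np.random.default_rng(seed).permutation(unique)
       folds = []
       for fold, val_patients in enumerate(np.array_split(shuffled, n_folds)):
           # Every patient not held out in this fold is used for training.
           train_patients = np.setdiff1d(unique, val_patients)
           folds.append((fold, train_patients, val_patients))
       return folds


   def collect_results(
       train_fn: Callable[..., Any],
       patient_ids: np.ndarray,
       n_folds: int,
       return_test_predictions: bool = True,
   ) -> Tuple[List[Dict[str, Any]], Dict[str, list]]:
       """Hypothetical driver: run train_fn per fold and pool test predictions."""
       all_fold_results: List[Dict[str, Any]] = []
       pooled: Dict[str, list] = {
           "patient_ids": [], "true_labels": [],
           "predicted_probs": [], "predicted_labels": [],
       }
       for fold, train_patients, val_patients in patient_folds(patient_ids, n_folds):
           out = train_fn(
               fold, train_patients, val_patients,
               return_test_predictions=return_test_predictions,
           )
           if return_test_predictions:
               # Documented return shape: (fold_results, test_predictions).
               fold_results, test_predictions = out
               for key in pooled:
                   pooled[key].extend(test_predictions[key])
           else:
               # Documented return shape: fold_results only.
               fold_results = out
           all_fold_results.append(fold_results)
       return all_fold_results, pooled

With the real function, ``train_fn`` could be built as ``functools.partial(train_single_fold, data_final=data_final, args=args, model_folder=model_folder, device=device)``. The three fold-specific arguments are then passed positionally. Because each patient is held out in exactly one fold, the pooled predictions cover every patient once, which is what an overall cross-fold evaluation needs.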