Loki2 Cell-level Multiple Instance Learning

This notebook shows how to train a multiple instance learning (MIL) model on Loki2 cell embeddings for whole slide image (WSI) classification and how to visualize the attention weights of the trained models.

MIL Model

Prepare training data:

  • Prepare your data as prampt files for the stage and metastasis tasks, using the code in ../src/mil/data/data_prepare/prampt_generate.

  • Downsample the cell embeddings with ../src/mil/data/data_prepare/prampt_downsample.py.
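The downsampling step can be pictured with a minimal sketch. It assumes uniform random sampling without replacement up to a fixed cell budget per slide; the helper `downsample_cells` and its parameters are hypothetical and are not taken from prampt_downsample.py, which may use a different strategy.

```python
import random


def downsample_cells(embeddings, max_cells, seed=0):
    """Return at most max_cells cell embeddings from one slide.

    Hypothetical helper for illustration only; it is not the logic
    used in prampt_downsample.py.
    """
    if len(embeddings) <= max_cells:
        return list(embeddings)
    rng = random.Random(seed)  # fixed seed keeps the subset reproducible
    keep = sorted(rng.sample(range(len(embeddings)), max_cells))
    return [embeddings[i] for i in keep]


# Toy slide: 10 cells, each with a 2-dimensional embedding
cells = [[float(i), float(i) * 0.5] for i in range(10)]
subset = downsample_cells(cells, max_cells=4)
print(len(subset))
```

Sorting the sampled indices preserves the original cell order, which keeps any per-cell metadata (such as coordinates) easy to align with the retained embeddings.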

Run MIL to generate attention weights:

  • Run the metastasis task with ../src/mil/models/TCGA-Loki_MILCell_blca_metastasis_5fold.py.

  • Run the stage task with ../src/mil/models/TCGA-Loki_MILCell_blca_stage_5fold.py.
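The script names suggest 5-fold cross-validation. How the scripts actually split the slides is not shown here, so the following is only a sketch of one common approach, a stratified split that keeps class proportions similar across folds. The function `stratified_folds` and the toy labels are hypothetical.

```python
import random
from collections import defaultdict


def stratified_folds(labels, n_folds=5, seed=0):
    """Assign each slide index to one fold, balancing classes across folds.

    Hypothetical helper; the training scripts may split differently.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    folds = [[] for _ in range(n_folds)]
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Deal indices round-robin so every fold gets a similar share of this class
        for pos, idx in enumerate(indices):
            folds[pos % n_folds].append(idx)
    return folds


# Toy labels: 10 slides of class 0 and 5 slides of class 1
labels = [0] * 10 + [1] * 5
folds = stratified_folds(labels)
for k, test_idx in enumerate(folds):
    train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
    print(f"fold {k}: {len(train_idx)} train, {len(test_idx)} test")
```

When a patient contributes several slides, all of that patient's slides should go into the same fold to avoid leakage between training and test sets; the sketch above does not handle that case.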

Plot Attention Heatmaps on WSI

Generate spatial heatmaps of model attention across whole slide images. The MIL models use an attention mechanism to weight the cells and regions that contribute most to each classification decision. These weights help show which morphological features drive predictions for tasks such as cancer staging and metastasis prediction.
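To make the role of the attention weights concrete, here is a minimal sketch of attention-based MIL pooling in its common form: each cell embedding gets a score, the scores are normalized with a softmax over cells, and the slide representation is the weighted sum. This is an illustrative assumption and not necessarily the architecture used by the Loki2 MIL scripts; the names `attention_pool`, `V`, and `w` are hypothetical.

```python
import math


def attention_pool(embeddings, V, w):
    """Attention-based MIL pooling over the cells of one slide.

    embeddings: list of N cell vectors, each of length D
    V: hidden projection, list of H rows of length D
    w: scoring vector of length H
    Returns (attention weights that sum to 1, pooled vector of length D).
    """
    scores = []
    for h in embeddings:
        hidden = [math.tanh(sum(v * x for v, x in zip(row, h))) for row in V]
        scores.append(sum(wi * ti for wi, ti in zip(w, hidden)))

    # Softmax over cells; subtracting the max keeps exp() numerically stable
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    attn = [e / total for e in exps]

    dim = len(embeddings[0])
    pooled = [sum(a * h[d] for a, h in zip(attn, embeddings)) for d in range(dim)]
    return attn, pooled


# Toy example: three cells with 2-dimensional embeddings
cells = [[0.1, 0.2], [2.0, -1.0], [0.0, 0.0]]
V = [[1.0, 0.0], [0.0, 1.0]]
w = [1.0, -1.0]
attn, slide_vec = attention_pool(cells, V, w)
print(attn)
```

The per-cell values in `attn` are the kind of quantity a heatmap would display: one weight per cell, plotted at that cell's position on the slide.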

The example data is stored in ../data/mil and can be downloaded from Google Drive.

You will need:

  • Attention weight files from MIL models

  • Whole slide images for visualization

import os
from pathlib import Path

import loki2.plot

# Parameters
max_dim = 4096  # Maximum side length for thumbnail
point_size = 3  # Scatter point size
alpha = 0.6  # Transparency
cmap = "rainbow"  # Colormap
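The `max_dim` parameter caps the longest side of the slide thumbnail. A minimal sketch of what that implies for plotting: the slide is shrunk by a single scale factor, and cell coordinates at full resolution are multiplied by the same factor so that points land on the right thumbnail pixels. This is an assumption about how such a parameter is typically used, not a description of the internals of `loki2.plot.plot_attention_on_slide`; both helpers below are hypothetical.

```python
def thumbnail_scale(width, height, max_dim):
    """Scale factor that makes the longer side at most max_dim (never upsamples)."""
    return min(1.0, max_dim / max(width, height))


def to_thumbnail_coords(points, scale):
    """Map (x, y) full-resolution pixel coordinates to thumbnail pixels."""
    return [(x * scale, y * scale) for x, y in points]


# Toy slide of 80000 x 40000 pixels, with max_dim = 4096 as in the parameters above
scale = thumbnail_scale(width=80000, height=40000, max_dim=4096)
thumb_size = (round(80000 * scale), round(40000 * scale))
center = to_thumbnail_coords([(40000, 20000)], scale)
print(thumb_size, center)
```

A larger `max_dim` gives a sharper background image at the cost of memory and rendering time; `point_size` may need adjusting to match.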

TCGA BLCA Stage Subtype MIL Results

Visualize attention weights from MIL models trained for cancer staging classification. This helps identify morphologically distinct regions that are predictive of tumor stage.

# Configure SVS file paths and attention file paths
slide_dict_stage = {
    "TCGA-2F-A9KW": "../data/mil/TCGA-2F-A9KW-01Z-00-DX1.CECFDA2E-2CE7-4115-B4E6-A3D75B130232.svs",
    "TCGA-4Z-AA80": "../data/mil/TCGA-4Z-AA80-01Z-00-DX1.303549D2-42A5-46C4-AD9D-D72337B416E5.svs",
}

attention_weights_stage = {
    "TCGA-2F-A9KW": "../data/mil/stage_attention.txt",
    "TCGA-4Z-AA80": "../data/mil/stage_attention.txt",
}
out_dir = Path("../outputs/mil/output/stage/")
os.makedirs(out_dir, exist_ok=True)

save_images = True  # If True → save PNG; if False → only display

for case_id, svs_path in slide_dict_stage.items():
    att_txt = attention_weights_stage.get(case_id)
    if att_txt is None:
        print(f"[SKIP] No attention file for {case_id}")
        continue

    # Parentheses make the precedence of the conditional expression explicit
    out_png = (out_dir / f"attention_heatmap_{case_id}.png") if save_images else None
    print(f"[PLOT] {case_id}")
    
    loki2.plot.plot_attention_on_slide(
        svs_path=svs_path,
        att_txt=att_txt,
        case_prefix=case_id,
        save_path=out_png,
        max_dim=max_dim,
        point_size=point_size,
        alpha=alpha,
        cmap=cmap,
    )

print("\n✅ All attention heatmaps processed!")
[PLOT] TCGA-2F-A9KW
[SAVED] ../outputs/mil/output/stage/attention_heatmap_TCGA-2F-A9KW.png
../_images/cdbda731f9a107487b7ab60a991a613a24f886d655db497a2651581318f27cbb.png
[PLOT] TCGA-4Z-AA80
[SAVED] ../outputs/mil/output/stage/attention_heatmap_TCGA-4Z-AA80.png
../_images/93713245e2b108e63f8efb9ed6fae08e299acaad2d1806460ae16939c9b50382.png
✅ All attention heatmaps processed!

TCGA BLCA Metastasis Subtype MIL Results

Visualize attention weights from MIL models trained for metastasis prediction. This highlights cells and regions that are most predictive of metastasis.

# Configure SVS file paths and attention file paths
slide_dict_stage_meta = {
    "TCGA-DK-A3X2": "../data/mil/TCGA-DK-A3X2-01Z-00-DX1.CB507611-E3AE-43A2-B5F8-EEAE59423E2E.svs",
    "TCGA-FD-A6TF": "../data/mil/TCGA-FD-A6TF-01Z-00-DX1.15B2C3E0-A0D7-4879-82B1-6C9AB09AF8E2.svs",
}

attention_weights_meta = {
    "TCGA-DK-A3X2": "../data/mil/metastasis_attention.txt",
    "TCGA-FD-A6TF": "../data/mil/metastasis_attention.txt",
}
out_dir = Path("../outputs/mil/output/metastasis/")
os.makedirs(out_dir, exist_ok=True)

save_images = True  # If True → save PNG; if False → only display

for case_id, svs_path in slide_dict_stage_meta.items():
    att_txt = attention_weights_meta.get(case_id)
    if att_txt is None:
        print(f"[SKIP] No attention file for {case_id}")
        continue

    # Parentheses make the precedence of the conditional expression explicit
    out_png = (out_dir / f"attention_heatmap_{case_id}.png") if save_images else None
    print(f"[PLOT] {case_id}")
    
    loki2.plot.plot_attention_on_slide(
        svs_path=svs_path,
        att_txt=att_txt,
        case_prefix=case_id,
        save_path=out_png,
        max_dim=max_dim,
        point_size=point_size,
        alpha=alpha,
        cmap=cmap,
    )

print("\n✅ All attention heatmaps processed!")
[PLOT] TCGA-DK-A3X2
[SAVED] ../outputs/mil/output/metastasis/attention_heatmap_TCGA-DK-A3X2.png
../_images/fa3aedb5e03871d35a6dccd8ea81243eb98d20d0fbb7d2f77b002bd1b20f3a0d.png
[PLOT] TCGA-FD-A6TF
[SAVED] ../outputs/mil/output/metastasis/attention_heatmap_TCGA-FD-A6TF.png
../_images/3fef0c280de3ba063f225ff2827e5ee46a3ee45ce29d269b14d1afedd940ea53.png
✅ All attention heatmaps processed!