Loki2 Morphology-to-Transcriptome Retrieval

This notebook demonstrates how to retrieve transcriptomic profiles for cells from their Loki2 morphological features using cross-modal contrastive learning.

Data Requirements

The example data is stored in the directory ../data/morph_retrieve and can be downloaded from Google Drive.

You will need:

  • Paired morphology-transcriptome data for training

  • Single-cell RNA-seq reference atlas

  • H&E image for which to retrieve transcriptomic profiles

  • OmiCLIP model checkpoint for transcriptome encoding

Prepare Training Data

Extract transcriptomic embeddings with OmiCLIP and pair them with the Loki2 morphological embeddings of the same cells. The resulting paired training data links morphological features to transcriptomic profiles and is split into WebDataset-format shards for efficient training.

Follow the Loki notebooks to create the conda environment loki_env for contrastive learning.
Place the OmiCLIP checkpoint omiclip.pt in ../src/loki2/.

conda activate loki_env

DATA_PATH="../data/morph_retrieve/P1CRC_VISIUMHD_LOKI2_mask.h5ad"
OUTPUT="../outputs/morph_retrieve/output/P1CRC_cell_trans_emb_raw.pt"

python ../src/loki2/encode_trans.py "$DATA_PATH" \
    --output "$OUTPUT" \
    --batch-size 1024 \
    --num-threads 32 \
    --device cuda

python ../src/loki2/cl/prepare_training.py \
    --dataset-name P1CRC \
    --trans-path "$OUTPUT" \
    --morph-path ../data/morph_retrieve/P1CRC_cell_morph_emb.pt \
    --output-dir ../outputs/morph_retrieve/output/P1CRC_train \
    --shard-size 10000
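The internals of prepare_training.py are not shown in this notebook. The pairing and sharding step it performs can be sketched as follows, assuming both embedding files reduce to a list of cell IDs plus an embedding matrix. The helper names, the `.trans.npy`/`.morph.npy` member names, and the manifest columns are hypothetical; the `shard-*.tar` file pattern matches what the training cell searches for, and WebDataset groups tar members into one sample by the part of the file name before the first dot. The train/validation split that produces the `train/` and `val/` directories is omitted.

```python
import csv
import io
import os
import tarfile
import tempfile

import numpy as np


def pair_embeddings(trans_ids, trans_emb, morph_ids, morph_emb):
    """Keep cells present in both modalities, with rows aligned by cell ID."""
    morph_row = {cid: i for i, cid in enumerate(morph_ids)}
    shared = [(cid, i) for i, cid in enumerate(trans_ids) if cid in morph_row]
    ids = [cid for cid, _ in shared]
    trans_rows = [i for _, i in shared]
    morph_rows = [morph_row[cid] for cid in ids]
    return ids, trans_emb[trans_rows], morph_emb[morph_rows]


def _npy_bytes(arr):
    """Serialize one embedding vector to .npy bytes."""
    buf = io.BytesIO()
    np.save(buf, np.asarray(arr, dtype=np.float32))
    return buf.getvalue()


def write_shards(ids, trans, morph, out_dir, shard_size=10000):
    """Write paired samples to shard-XXXXXX.tar files plus a CSV manifest (hypothetical layout)."""
    os.makedirs(out_dir, exist_ok=True)
    manifest = []
    for start in range(0, len(ids), shard_size):
        shard_name = f"shard-{start // shard_size:06d}.tar"
        with tarfile.open(os.path.join(out_dir, shard_name), "w") as tar:
            for j in range(start, min(start + shard_size, len(ids))):
                key = f"{j:09d}"  # members sharing this key form one sample
                for ext, vec in (("trans.npy", trans[j]), ("morph.npy", morph[j])):
                    data = _npy_bytes(vec)
                    info = tarfile.TarInfo(name=f"{key}.{ext}")
                    info.size = len(data)
                    tar.addfile(info, io.BytesIO(data))
                manifest.append((key, ids[j], shard_name))
    with open(os.path.join(out_dir, "manifest.csv"), "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["key", "cell_id", "shard"])
        writer.writerows(manifest)
    return manifest


# Toy example: 25 transcriptome cells and 25 morphology cells, 20 of them shared.
rng = np.random.default_rng(0)
trans_ids = [f"cell{i}" for i in range(25)]
morph_ids = [f"cell{i}" for i in range(5, 30)]
trans_emb = rng.normal(size=(25, 8))
morph_emb = rng.normal(size=(25, 8))
ids, trans, morph = pair_embeddings(trans_ids, trans_emb, morph_ids, morph_emb)
out_dir = tempfile.mkdtemp()
manifest = write_shards(ids, trans, morph, out_dir, shard_size=8)
```

With 20 shared cells and a shard size of 8, this writes three shards holding 8, 8, and 4 samples.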

Training

Train a projection network with contrastive learning to align the morphological and transcriptomic embedding spaces. The model projects embeddings from both modalities into a shared space in which the morphology and transcriptome embeddings of the same cell lie close together, enabling cross-modal retrieval between morphology and transcriptome.

conda activate loki_env

DATASET="${DATASET:-P1CRC}"
RUN_NAME="${RUN_NAME:-${DATASET}_wds_vanilla}"
RUN_DIR="../outputs/morph_retrieve/output/runs/${RUN_NAME}"

TRAIN_DIR="../outputs/morph_retrieve/output/P1CRC_train/${DATASET}/train"
VAL_DIR="../outputs/morph_retrieve/output/P1CRC_train/${DATASET}/val"
TRAIN_META="${TRAIN_DIR}/manifest_train.csv"
VAL_META="${VAL_DIR}/manifest_val.csv"

if [[ ! -f "${TRAIN_META}" ]]; then
  echo "Missing train manifest: ${TRAIN_META}" >&2
  exit 1
fi

if [[ ! -f "${VAL_META}" ]]; then
  echo "Missing validation manifest: ${VAL_META}" >&2
  exit 1
fi

mapfile -t TRAIN_SHARDS < <(find "${TRAIN_DIR}" -maxdepth 1 -type f -name 'shard-*.tar' | sort)
mapfile -t VAL_SHARDS < <(find "${VAL_DIR}" -maxdepth 1 -type f -name 'shard-*.tar' | sort)

if [[ ${#TRAIN_SHARDS[@]} -eq 0 ]]; then
  echo "No training shards found in ${TRAIN_DIR}" >&2
  exit 1
fi

if [[ ${#VAL_SHARDS[@]} -eq 0 ]]; then
  echo "No validation shards found in ${VAL_DIR}" >&2
  exit 1
fi

TRAIN_SHARD_LIST=$(IFS=, ; echo "${TRAIN_SHARDS[*]}")
VAL_SHARD_LIST=$(IFS=, ; echo "${VAL_SHARDS[*]}")

mkdir -p "${RUN_DIR}"

python ../src/loki2/cl/train_projection_wds.py \
  --train-shards "${TRAIN_SHARD_LIST}" \
  --train-meta "${TRAIN_META}" \
  --val-shards "${VAL_SHARD_LIST}" \
  --val-meta "${VAL_META}" \
  --num-layers 1 \
  --batch-size 4096 \
  --epochs 20 \
  --lr 5e-4 \
  --device cuda \
  --amp \
  --save-every 1 \
  --output-dir "${RUN_DIR}" \
  --log-file train.log
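The loss used by train_projection_wds.py is not shown in this notebook. A common objective for this kind of cross-modal alignment is a CLIP-style symmetric InfoNCE loss; the NumPy sketch below assumes that objective and an arbitrary temperature of 0.07, and shows only the forward computation on already-projected embeddings, not the projection network or its training.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Scale each row to unit length so dot products become cosine similarities."""
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


def log_softmax(logits, axis):
    """Numerically stable log-softmax along one axis."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def clip_loss(morph_proj, trans_proj, temperature=0.07):
    """Symmetric InfoNCE: row i of each batch is the positive pair, all other rows are negatives."""
    logits = l2_normalize(morph_proj) @ l2_normalize(trans_proj).T / temperature
    diag = np.arange(len(logits))
    loss_m2t = -log_softmax(logits, axis=1)[diag, diag].mean()  # morphology -> transcriptome
    loss_t2m = -log_softmax(logits, axis=0)[diag, diag].mean()  # transcriptome -> morphology
    return 0.5 * (loss_m2t + loss_t2m)


rng = np.random.default_rng(0)
morph_proj = rng.normal(size=(64, 32))
# Nearly identical pairs give a loss close to zero; unrelated pairs give a large loss.
aligned_loss = clip_loss(morph_proj, morph_proj + 0.01 * rng.normal(size=(64, 32)))
random_loss = clip_loss(morph_proj, rng.normal(size=(64, 32)))
```

Because every other cell in the batch serves as a negative, losses of this form generally benefit from large batches, which is consistent with the `--batch-size 4096` used above.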

Retrieve from scRNA Data

Encode the single-cell RNA-seq reference with OmiCLIP, then project the spatial morphological embeddings and the reference transcriptomic embeddings into the shared space. Nearest-neighbor search in that space retrieves matching transcriptomic profiles for each spatial cell from its morphology, which enables transferring cell-type labels from scRNA-seq to spatial data and predicting gene expression from morphology.

conda activate loki_env

sam="CRC_sc"
DATA_PATH="../data/morph_retrieve/CRC_sc.h5ad"
OUTPUT="../outputs/morph_retrieve/output/${sam}_trans.pt"

python ../src/loki2/encode_trans.py "$DATA_PATH" \
    --output "$OUTPUT" \
    --batch-size 1024 \
    --num-threads 32 \
    --device cuda

conda deactivate
conda activate loki2_env

CHECKPOINT_DIR="../outputs/morph_retrieve/output/runs/P1CRC_wds_vanilla"

# Set the epoch to use for projection (default: 20)
EPOCH=${EPOCH:-20}
CKPT_PATH="${CHECKPOINT_DIR}/projection_cl_epoch${EPOCH}.pt"

if [[ ! -f "${CKPT_PATH}" ]]; then
  echo "Checkpoint not found: ${CKPT_PATH}" >&2
  exit 1
fi

declare -A SAMPLE_MAP=(
  ["Cancer_P2"]="P2CRC"
)

for label in Cancer_P2; do
  dataset="${SAMPLE_MAP[$label]}"
  morph_path="../data/morph_retrieve/${SAMPLE_MAP[$label]}_cell_morph_emb.pt"
  output_dir="../outputs/morph_retrieve/output/data_projection/${dataset}"

  if [[ ! -f "${morph_path}" ]]; then
    echo "Skipping ${label}: missing morphology embeddings at ${morph_path}" >&2
    continue
  fi

  echo "Projecting ${label} (dataset: ${dataset})"
  mkdir -p "${output_dir}"

  python ../src/loki2/cl/project_raw_embeddings.py \
    --checkpoint "${CKPT_PATH}" \
    --morph-path "${morph_path}" \
    --modality morph \
    --batch-size 4096 \
    --normalized \
    --tag "${label}_epoch${EPOCH}" \
    --output-dir "${output_dir}"
done

TRANS_PATH="../outputs/morph_retrieve/output/CRC_sc_trans.pt"
OUTPUT_DIR="../outputs/morph_retrieve/output/data_projection/sc"

if [[ ! -f "${TRANS_PATH}" ]]; then
  echo "Transcriptomic embeddings missing: ${TRANS_PATH}" >&2
  exit 1
fi

mkdir -p "${OUTPUT_DIR}"

python ../src/loki2/cl/project_raw_embeddings.py \
  --checkpoint "${CKPT_PATH}" \
  --trans-path "${TRANS_PATH}" \
  --modality trans \
  --batch-size 4096 \
  --normalized \
  --tag "epoch${EPOCH}" \
  --output-dir "${OUTPUT_DIR}"
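project_raw_embeddings.py is likewise not shown. The sketch below illustrates what projecting and normalizing morphology embeddings could look like, assuming the trained head is a single linear layer given as hypothetical `weight` and `bias` arrays and that `--normalized` corresponds to L2-normalizing each projected vector. The output keys mirror those listed when the morphology projection file is loaded later in this notebook; the keys of the transcriptome output are not shown and are not reproduced here.

```python
import os
import tempfile

import numpy as np


def project_and_save(raw, weight, bias, cell_id, positions, out_path, batch_size=4096):
    """Project raw morphology embeddings in batches, L2-normalize, and save an .npz file.

    Hypothetical: assumes the projection head is a single linear layer (weight, bias).
    """
    chunks = [raw[s:s + batch_size] @ weight + bias for s in range(0, len(raw), batch_size)]
    proj = np.concatenate(chunks).astype(np.float32)
    proj_norm = proj / np.maximum(np.linalg.norm(proj, axis=1, keepdims=True), 1e-12)
    np.savez(
        out_path,
        row_index=np.arange(len(raw)),
        cell_id=np.asarray(cell_id),
        positions=np.asarray(positions),
        morph_proj=proj,
        morph_proj_norm=proj_norm,
    )
    return proj_norm


# Toy example with random "raw" embeddings and a random linear head.
rng = np.random.default_rng(0)
raw = rng.normal(size=(10, 6))
out_path = os.path.join(tempfile.mkdtemp(), "morph_proj_demo.npz")
project_and_save(raw, rng.normal(size=(6, 4)), np.zeros(4), [f"cell{i}" for i in range(10)],
                 rng.uniform(0, 1000, size=(10, 2)), out_path, batch_size=3)
demo = np.load(out_path)
```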

declare -A SAMPLE_MAP=(
  ["Cancer_P2"]="P2CRC"
)

for label in "${!SAMPLE_MAP[@]}"; do
    dataset="${SAMPLE_MAP[$label]}"
    output_dir=${3:-"../outputs/morph_retrieve/output/result_centroid/retrieve_epoch${EPOCH}"}
    mkdir -p "${output_dir}"
    echo "Processing sample: ${dataset}, output directory: ${output_dir}"
    python ../src/loki2/retrieve_from_sc.py "${dataset}" "${label}" "${EPOCH}" "${output_dir}"
done
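retrieve_from_sc.py is not shown in this notebook either. A minimal sketch of the nearest-neighbor step described above, assuming brute-force cosine similarity between L2-normalized embeddings (the function `topk_cosine` is hypothetical): for each query cell it returns the indices of the k most similar reference cells and their scores, sorted best-first, matching how the notebook later treats `indices[:, 0]` as the top-1 match.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Scale each row to unit length so dot products become cosine similarities."""
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), eps)


def topk_cosine(query, reference, k=10, batch_size=4096):
    """Brute-force cosine top-k search; returns (indices, scores) sorted best-first."""
    q, r = l2_normalize(query), l2_normalize(reference)
    k = min(k, len(r))
    all_idx, all_scores = [], []
    for start in range(0, len(q), batch_size):
        sims = q[start:start + batch_size] @ r.T  # (batch, n_reference) cosine similarities
        idx = np.argpartition(-sims, kth=k - 1, axis=1)[:, :k]  # unordered top-k
        top = np.take_along_axis(sims, idx, axis=1)
        order = np.argsort(-top, axis=1)  # sort the k candidates best-first
        all_idx.append(np.take_along_axis(idx, order, axis=1))
        all_scores.append(np.take_along_axis(top, order, axis=1))
    return np.concatenate(all_idx), np.concatenate(all_scores)


# Toy example: queries are slightly perturbed copies of reference rows 3, 42, and 7.
rng = np.random.default_rng(1)
reference = rng.normal(size=(100, 16))
query = reference[[3, 42, 7]] + 1e-3 * rng.normal(size=(3, 16))
demo_idx, demo_scores = topk_cosine(query, reference, k=5, batch_size=2)
```

Batching the queries bounds memory to one `(batch, n_reference)` similarity matrix at a time, which matters for a reference of roughly 260,000 cells.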

import numpy as np
import pandas as pd
import scanpy as sc
import torch
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches

import loki2.retrieve
import loki2.plot

Load Retrieval Results

Load the projected embeddings and retrieval results. The retrieval results contain indices of matching cells from the single-cell reference atlas and their similarity scores, which can be used to predict cell types and gene expression patterns.

sample_name = "P2CRC"
proj_path = "../outputs/morph_retrieve/output/data_projection/P2CRC/morph_proj_Cancer_P2_epoch20.npz"
result_dir = "../outputs/morph_retrieve/output/result_centroid/retrieve_epoch20"
celltype_path = f"{result_dir}/cell_type_assignments_{sample_name}.csv"
result_path = f"{result_dir}/retr_from_atlas_{sample_name}.pt"
data_morph = np.load(proj_path, allow_pickle=True)
list(data_morph.keys())
['row_index', 'cell_id', 'positions', 'morph_proj', 'morph_proj_norm']
celltype_df = pd.read_csv(celltype_path, index_col=0)
result = torch.load(result_path)
indices = result['indices']
values = result['scores']
sc_ad = sc.read_h5ad("../data/morph_retrieve/CRC_sc.h5ad")
sc_ad.obsm['X_umap'] = sc_ad.obs[['UMAP1', 'UMAP2']].values
sc_ad
AnnData object with n_obs × n_vars = 260506 × 18082
    obs: 'Patient', 'BC', 'QCFilter', 'Level1', 'Level2', 'UMAP1', 'UMAP2'
    var: 'gene_ids', 'feature_types', 'genome'
    obsm: 'X_umap'
source_labels = sc_ad.obs['Level1']
level1_retrieved_top1 = source_labels.iloc[indices[:, 0]]
preds, _ = loki2.retrieve.knn_majority_vote(indices, source_labels)
preds_w, counts = loki2.retrieve.knn_majority_vote(indices, source_labels, scores=values)
preds_t, _ = loki2.retrieve.knn_majority_vote(indices, source_labels, scores=values, temperature=0.07)
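loki2.retrieve.knn_majority_vote is called three ways above: plain votes, score-weighted votes, and score-weighted votes with a temperature. Its implementation is not shown; the sketch below is one plausible reading of those three modes, using softmax(scores / temperature) as the tempered weights. The name `weighted_knn_vote`, the weighting rules, and the toy labels are assumptions, not the library's code.

```python
import numpy as np


def weighted_knn_vote(indices, labels, scores=None, temperature=None):
    """Predict one label per query cell from its k retrieved reference cells.

    Hypothetical sketch: plain votes when scores is None, score-weighted votes
    when only scores is given, and softmax(scores / temperature) weights otherwise.
    """
    indices = np.asarray(indices)
    neighbor_labels = np.asarray(labels)[indices]  # (n_query, k) labels of neighbors
    if scores is None:
        weights = np.ones(indices.shape, dtype=float)
    elif temperature is None:
        weights = np.asarray(scores, dtype=float)
    else:
        logits = np.asarray(scores, dtype=float) / temperature
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        weights = np.exp(logits)
        weights /= weights.sum(axis=1, keepdims=True)
    # Encode neighbor labels as integer class codes shared across all queries.
    classes, codes = np.unique(neighbor_labels.ravel(), return_inverse=True)
    codes = codes.reshape(neighbor_labels.shape)
    rows = np.broadcast_to(np.arange(len(indices))[:, None], codes.shape)
    votes = np.zeros((len(indices), len(classes)))
    np.add.at(votes, (rows, codes), weights)  # accumulate weight per (query, class)
    return classes[votes.argmax(axis=1)], votes


demo_labels = np.array(["Epi", "Fibro", "Epi", "Immune", "Fibro"])
demo_indices = np.array([[0, 2, 1], [1, 4, 3]])
demo_scores = np.array([[0.3, 0.3, 0.9], [0.5, 0.5, 0.6]])
plain, plain_votes = weighted_knn_vote(demo_indices, demo_labels)
weighted, _ = weighted_knn_vote(demo_indices, demo_labels, scores=demo_scores)
tempered, _ = weighted_knn_vote(demo_indices, demo_labels, scores=demo_scores, temperature=0.07)
```

In this toy example the first query cell has two weakly matching "Epi" neighbors and one strongly matching "Fibro" neighbor, so plain voting and score weighting disagree; in the second, a sharp temperature lets a single best match outvote two weaker ones.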

Plot scRNA Data

Visualize the single-cell RNA-seq reference atlas in UMAP space, colored by cell type annotations. This provides context for understanding the reference data used for retrieval.

# ensure the category order matches your desired mapping
sc_ad.obs['Level1'] = sc_ad.obs['Level1'].astype('category')
# apply the colors to the categories
sc_ad.uns['Level1_colors'] = [loki2.plot.SC_COLOR_DICT[c] for c in sc_ad.obs['Level1'].cat.categories]
# plot
sc.pl.umap(sc_ad, color='Level1', frameon=False)
[Figure: UMAP of the scRNA-seq reference atlas colored by Level1 cell type]
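`sc.pl.umap` takes the colors for a categorical column from `uns['<column>_colors']`, matched to the column's categories in order, which is why the cell above builds the list from `cat.categories`. Indexing `SC_COLOR_DICT[c]` raises a `KeyError` for any category without an entry. A small sketch of a fallback, mirroring the `.get(cls, ...)` fallback used in the spatial legend below (the helper name, the gray default, and the toy labels and colors are hypothetical):

```python
import pandas as pd


def palette_for_categories(categories, color_dict, fallback="#808080"):
    """Return one color per category, in category order, with a fallback for unmapped labels."""
    return [color_dict.get(cat, fallback) for cat in categories]


# Toy example: "Stromal" has no entry in the color dictionary.
demo_labels = pd.Series(["Epithelial", "Immune", "Stromal", "Epithelial"]).astype("category")
demo_colors = palette_for_categories(
    demo_labels.cat.categories, {"Epithelial": "#1f77b4", "Immune": "#ff7f0e"}
)
```

In the notebook this would read `sc_ad.uns['Level1_colors'] = palette_for_categories(sc_ad.obs['Level1'].cat.categories, loki2.plot.SC_COLOR_DICT)`.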

Plot Loki2 Retrieval Results

Visualize the retrieved cell types and gene expression patterns mapped onto the spatial tissue coordinates.

plt.figure(figsize=(6, 6), dpi=150)
plt.scatter(data_morph['positions'][:, 0], data_morph['positions'][:, 1], 
            c=loki2.plot.labels_to_hex(preds_t, loki2.plot.SC_COLOR_DICT), 
            s=0.2, edgecolors='none', linewidths=0)
plt.gca().invert_yaxis()
plt.gca().set_aspect('equal')
plt.title("Loki2 Retrieved Cell Types")
plt.axis("off")

# -------- Legend --------
unique_classes = np.unique(preds_t)

handles = []
for cls in unique_classes:
    color = loki2.plot.SC_COLOR_DICT.get(cls, "#000000")  # fallback color if missing
    label = cls
    patch = mpatches.Patch(color=color, label=label)
    handles.append(patch)

plt.legend(
    handles=handles,
    loc="upper right",
    fontsize=6,
    frameon=False,
    bbox_to_anchor=(1.25, 1.0)   # nudge legend outside the plot if needed
)

plt.show()
[Figure: Loki2 retrieved cell types at spatial cell positions]
gene = 'CEACAM6'
loki2.plot.plot_gene(sc_ad, indices, data_morph, gene)
[Figure: Loki2 retrieved CEACAM6 expression]
gene = 'COL3A1'
loki2.plot.plot_gene(sc_ad, indices, data_morph, gene)
[Figure: Loki2 retrieved COL3A1 expression]
gene = 'PIGR'
loki2.plot.plot_gene(sc_ad, indices, data_morph, gene)
[Figure: Loki2 retrieved PIGR expression]
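How loki2.plot.plot_gene derives the plotted values is not shown, including whether it normalizes or log-transforms expression. One plausible approach, sketched below under that uncertainty, is to average a gene's reference expression over each spatial cell's retrieved neighbors, optionally weighting by similarity score. The function name `retrieved_expression` is hypothetical.

```python
import numpy as np


def retrieved_expression(ref_values, neighbor_idx, neighbor_scores=None):
    """Estimate a gene's expression for each query cell from its retrieved reference cells.

    ref_values: (n_reference,) expression of one gene in the reference atlas.
    neighbor_idx: (n_query, k) reference rows retrieved for each query cell.
    neighbor_scores: optional (n_query, k) similarity scores used as weights
    (assumed positive, as they typically are for the top matches).
    """
    vals = np.asarray(ref_values, dtype=float)[np.asarray(neighbor_idx)]  # (n_query, k)
    if neighbor_scores is None:
        return vals.mean(axis=1)
    w = np.asarray(neighbor_scores, dtype=float)
    return (vals * w).sum(axis=1) / w.sum(axis=1)


# Toy example: 5 reference cells, 2 query cells with 2 retrieved neighbors each.
demo_ref = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
demo_idx = np.array([[0, 4], [1, 2]])
demo_scores = np.array([[1.0, 3.0], [0.5, 0.5]])
mean_expr = retrieved_expression(demo_ref, demo_idx)
weighted_expr = retrieved_expression(demo_ref, demo_idx, demo_scores)
```

With the notebook's objects, the reference vector for a gene such as PIGR could be taken from `X = sc_ad[:, 'PIGR'].X`, converted with `X.toarray().ravel()` when `scipy.sparse.issparse(X)` is true (or `np.asarray(X).ravel()` otherwise), and passed with `indices` and, optionally, `values`.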