Basic Usage of OmiCLIP Model

This notebook demonstrates the basic usage of the OmiCLIP model, including loading the model, encoding text and images, and calculating the similarity between image and text embeddings, together with examples of preprocessing ST data, scRNA-seq data, bulk RNA-seq data, and a whole image. It takes about 30 minutes to run this notebook on CPU on a MacBook Pro.

[1]:
import os
import pandas as pd
import numpy as np
import scanpy as sc
import anndata
from PIL import Image

import loki.utils
import loki.preprocess
/opt/anaconda3/envs/loki_env/lib/python3.9/site-packages/timm/models/layers/__init__.py:48: FutureWarning: Importing from timm.models.layers is deprecated, please import via timm.layers
  warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.layers", FutureWarning)

The sample data are stored in the directory data/basic_usage and can be downloaded from the Google Drive link.

[2]:
data_dir = './data/basic_usage/'

Load OmiCLIP Model

The pretrained weights are available on Hugging Face.
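If you have not yet obtained checkpoint.pt, one option is to fetch it with huggingface_hub. This is only a minimal sketch: the repo ID below is a placeholder, not the actual repository name, so replace it with the OmiCLIP repository listed on Hugging Face.

from huggingface_hub import hf_hub_download

# Placeholder repo ID -- substitute the OmiCLIP repository on Hugging Face.
checkpoint_path = hf_hub_download(repo_id='ORG/OmiCLIP', filename='checkpoint.pt',
                                  local_dir=data_dir)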

[3]:
model_path = os.path.join(data_dir, 'checkpoint.pt')
device = 'cpu'
[4]:
model, preprocess, tokenizer = loki.utils.load_model(model_path, device)
[5]:
model.eval()
[5]:
CoCa(
  (text): TextTransformer(
    (token_embedding): Embedding(49408, 768)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-11): 12 x ResidualAttentionBlock(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=768, out_features=3072, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (visual): VisionTransformer(
    (conv1): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
    (patch_dropout): Identity()
    (ln_pre): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (transformer): Transformer(
      (resblocks): ModuleList(
        (0-23): 24 x ResidualAttentionBlock(
          (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
          )
          (ls_1): Identity()
          (ln_2): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (mlp): Sequential(
            (c_fc): Linear(in_features=1024, out_features=4096, bias=True)
            (gelu): GELU(approximate='none')
            (c_proj): Linear(in_features=4096, out_features=1024, bias=True)
          )
          (ls_2): Identity()
        )
      )
    )
    (attn_pool): AttentionalPooler(
      (attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
      )
      (ln_q): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln_k): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (ln_post): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (text_decoder): MultimodalTransformer(
    (resblocks): ModuleList(
      (0-11): 12 x ResidualAttentionBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (ls_1): Identity()
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
        (ls_2): Identity()
      )
    )
    (cross_attn): ModuleList(
      (0-11): 12 x ResidualAttentionBlock(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (ls_1): Identity()
        (ln_1_kv): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): Sequential(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
        )
        (ls_2): Identity()
      )
    )
    (ln_final): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)

Encode Image

Use OmiCLIP to encode an image.

[6]:
image_path = os.path.join(data_dir, 'demo_data', 'TUM-TCGA-TLSHWGSQ.tif')
image = Image.open(image_path)
image
[6]:
[Image output: the image patch TUM-TCGA-TLSHWGSQ.tif]
[7]:
image_embeddings = loki.utils.encode_images(model, preprocess, [image_path], device)
image_embeddings.shape
[7]:
torch.Size([1, 768])
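For reference, the helper above roughly wraps the standard CLIP-style encoding path. A minimal sketch of the equivalent low-level call, assuming the open_clip/CoCa conventions suggested by the model printout (the loki helper may add batching or normalization details):

import torch

with torch.no_grad():
    img_tensor = preprocess(image).unsqueeze(0).to(device)   # (1, 3, H, W)
    manual_image_emb = model.encode_image(img_tensor)         # (1, 768)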

Encode Text

Use OmiCLIP to encode text.

[8]:
text = ['TP53 EPCAM KRAS EGFR DEFA5 DEFA6 CEACAM5 CEA KRT18 KRT8 KRT19 CDH17 CK20 MYO6 TP53BP2 PLA2G2A CLDN7 TJP1 PKP3 DSP']
text_embeddings = loki.utils.encode_texts(model, tokenizer, text, device)
text_embeddings.shape
[8]:
torch.Size([1, 768])
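Analogously to the image case, a minimal sketch of the low-level text call, assuming an open_clip-style tokenizer and encode_text (the loki helper may differ in batching or normalization):

import torch

with torch.no_grad():
    tokens = tokenizer(text).to(device)            # (1, context_length) token ids
    manual_text_emb = model.encode_text(tokens)    # (1, 768)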

Calculate Similarity

Calculate the similarity between the image embeddings and text embeddings.

[9]:
dot_similarity = image_embeddings @ text_embeddings.T
[10]:
dot_similarity
[10]:
tensor([[0.3926]])
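The dot product above equals cosine similarity only if both embeddings are L2-normalized, which is the usual convention for CLIP-style encoders. If you are unsure whether the returned embeddings are unit-length, computing cosine similarity explicitly gives a norm-independent score (minimal sketch):

import torch.nn.functional as F

# Norm-independent similarity; matches the dot product above whenever both
# vectors are already unit-length.
cosine_similarity = F.cosine_similarity(image_embeddings, text_embeddings)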

Examples of preprocessing ST data, scRNA-seq data, bulk RNA-seq data, and a whole image

Examples of using an .h5ad object (ST or scRNA-seq data) and a .gct file (bulk RNA-seq data) to generate text embeddings, and of using a whole image to generate image patch embeddings.

Encode a transcriptome from an AnnData object (ST or scRNA-seq data)

Example of encoding a transcriptome from an .h5ad object. The housekeeping gene list is obtained from the Molecular Signatures Database (MSigDB): https://www.gsea-msigdb.org/gsea/msigdb/. It takes about 3 minutes to run this section on CPU on a MacBook Pro.

[11]:
ad_path = os.path.join(data_dir, 'demo_data', 'RZ_GT_P2.h5ad')
ad = sc.read_h5ad(ad_path)
ad
[11]:
AnnData object with n_obs × n_vars = 3538 × 14467
    obs: 'n_counts', 'n_genes', 'percent.mt', 'Adipocyte', 'Cardiomyocyte', 'Endothelial', 'Fibroblast', 'Lymphoid', 'Mast', 'Myeloid', 'Neuronal', 'Pericyte', 'Cycling.cells', 'vSMCs', 'cell_type_original', 'assay_ontology_term_id', 'cell_type_ontology_term_id', 'development_stage_ontology_term_id', 'disease_ontology_term_id', 'ethnicity_ontology_term_id', 'is_primary_data', 'organism_ontology_term_id', 'sex_ontology_term_id', 'tissue_ontology_term_id', 'leiden', 'pixel_y', 'pixel_x', 'cell_type'
    var: 'features'
    uns: 'X_approximate_distribution', 'X_normalization', 'default_embedding', 'leiden', 'neighbors', 'schema_version', 'spatial', 'title'
    obsm: 'X_pca', 'X_spatial', 'X_umap', 'spatial'
    obsp: 'connectivities', 'distances'
[12]:
house_keeping_genes = pd.read_csv(os.path.join(data_dir, 'demo_data', 'housekeeping_genes.csv'), index_col = 0)
top_k_genes_str = loki.preprocess.generate_gene_df(ad, house_keeping_genes)
top_k_genes_str
[12]:
label
RZ_GT_P2_AAACAAGTATCTCCCA-1 TTN MYH7 MB MALAT1 MTRNR2L12 TPM1 DES CRYAB AP...
RZ_GT_P2_AAACAATCTACTAGCA-1 MYH7 MTRNR2L12 TTN CRYAB MYL2 MB TNNT2 DES TPM...
RZ_GT_P2_AAACACCAATAACTGC-1 TTN MYH7 MTRNR2L12 MB MALAT1 CRYAB TNNI3 DES A...
RZ_GT_P2_AAACAGAGCGACTCCT-1 TTN MYH7 MALAT1 MTRNR2L12 MYL2 MB TPM1 MYL3 CR...
RZ_GT_P2_AAACAGCTTTCAGAAG-1 MB MYH7 TTN CRYAB MTRNR2L12 TNNI3 TPM1 MYL2 DE...
... ...
RZ_GT_P2_TTGTTGTGTGTCAAGA-1 DES MB TCAP MYH7 TPM1 FN1 CMYA5 TNNI3 COL3A1 S...
RZ_GT_P2_TTGTTTCACATCCAGG-1 TTN MYH7 MTRNR2L12 MALAT1 MB CRYAB DES TNNC1 M...
RZ_GT_P2_TTGTTTCATTAGTCTA-1 TTN MALAT1 MYH7 MTRNR2L12 MB DES CRYAB TNNI3 T...
RZ_GT_P2_TTGTTTCCATACAACT-1 MTRNR2L12 MB MYH7 TTN CRYAB IGKC DES MALAT1 TP...
RZ_GT_P2_TTGTTTGTGTAAATTC-1 TTN MTRNR2L12 MYH7 TNNI3 MALAT1 MB DES CRYAB T...

3538 rows × 1 columns

[13]:
text_embeddings = loki.utils.encode_text_df(model, tokenizer, top_k_genes_str, 'label', device)
text_embeddings.shape
[13]:
torch.Size([3538, 768])
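Each row above is a per-spot "sentence" of top-ranked gene symbols, which is what the text encoder consumes. A conceptual sketch of how such a string can be built for a single spot is shown below; it assumes ranking genes by expression, dropping housekeeping genes, and keeping the top k symbols, and is a hypothetical illustration rather than the loki.preprocess.generate_gene_df implementation.

# Hypothetical helper: rank the genes of one spot by expression, drop
# housekeeping genes, and join the top-k symbols into a text prompt.
def spot_to_gene_sentence(adata, spot_idx, housekeeping_set, k=20):
    expr = adata.X[spot_idx]
    if hasattr(expr, 'toarray'):          # sparse row
        expr = expr.toarray()
    expr = np.asarray(expr).ravel()
    order = np.argsort(expr)[::-1]        # highest expression first
    ranked = adata.var_names[order]
    top = [g for g in ranked if g not in housekeeping_set][:k]
    return ' '.join(top)

# e.g. spot_to_gene_sentence(ad, 0, set(house_keeping_genes.index))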

Encode a transcriptome from a GCT file (bulk RNA-seq data)

Example of encoding a transcriptome from a .gct file.

[14]:
sc_data_path = os.path.join(data_dir, 'demo_data', 'bulk_fibroblasts.gct')
gct_data = loki.preprocess.read_gct(sc_data_path)
gct_data
[14]:
id Name Description GTEX-111VG-0008-SM-5Q5BG GTEX-111YS-0008-SM-5Q5BH GTEX-1122O-0008-SM-5QGR2 GTEX-1128S-0008-SM-5Q5DP GTEX-113IC-0008-SM-5QGRF GTEX-113JC-0008-SM-5QGR6 GTEX-117XS-0008-SM-5Q5DQ ... GTEX-ZVE2-0008-SM-51MRU GTEX-ZVP2-0008-SM-51MSL GTEX-ZVT2-0008-SM-57WC9 GTEX-ZVT3-0008-SM-51MRI GTEX-ZVT4-0008-SM-57WCA GTEX-ZVTK-0008-SM-57WDA GTEX-ZVZP-0008-SM-51MSX GTEX-ZVZQ-0008-SM-51MSK GTEX-ZXES-0008-SM-57WCX GTEX-ZXG5-0008-SM-57WDB
0 0 ENSG00000223972.5 DDX11L1 0 0 0 0 0 0 0 ... 0 0 0 0 1 0 0 0 0 0
1 1 ENSG00000227232.5 WASH7P 21 31 16 41 57 46 66 ... 24 62 40 42 64 58 62 52 20 32
2 2 ENSG00000278267.1 MIR6859-1 0 0 0 0 0 0 0 ... 0 0 0 0 0 0 0 0 0 0
3 3 ENSG00000243485.5 MIR1302-2HG 0 0 0 0 1 0 0 ... 1 0 0 0 1 0 0 1 0 0
4 4 ENSG00000237613.2 FAM138A 0 0 1 0 0 0 0 ... 1 0 0 0 0 0 0 0 0 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
56195 56195 ENSG00000198695.2 MT-ND6 35691 34685 50128 81334 60140 81957 48581 ... 60506 94854 76311 87410 63940 108716 70445 78721 68057 53557
56196 56196 ENSG00000210194.1 MT-TE 8 19 28 47 31 40 25 ... 35 53 23 34 21 52 29 32 19 27
56197 56197 ENSG00000198727.2 MT-CYB 161038 154424 138836 246421 194695 225034 214565 ... 269451 296274 335286 274560 186433 312236 499192 258581 241541 173705
56198 56198 ENSG00000210195.2 MT-TT 0 1 1 6 2 0 2 ... 1 1 3 3 6 3 3 1 1 0
56199 56199 ENSG00000210196.2 MT-TP 3 0 0 3 1 0 3 ... 1 0 1 1 7 4 1 0 1 2

56200 rows × 507 columns

[15]:
bulk_ad = anndata.AnnData(pd.DataFrame(gct_data.iloc[:, 3:].mean(axis=1)).T)
bulk_ad.var.index = gct_data['Description']
bulk_text_feature = loki.preprocess.generate_gene_df(bulk_ad, house_keeping_genes, todense=False)
bulk_text_feature
/opt/anaconda3/envs/loki_env/lib/python3.9/site-packages/anndata/_core/aligned_df.py:68: ImplicitModificationWarning: Transforming to str index.
  warnings.warn("Transforming to str index.", ImplicitModificationWarning)
/opt/anaconda3/envs/loki_env/lib/python3.9/site-packages/anndata/_core/aligned_df.py:68: ImplicitModificationWarning: Transforming to str index.
  warnings.warn("Transforming to str index.", ImplicitModificationWarning)
[15]:
label
0 FN1 COL1A1 COL1A2 COL6A3 TGFBI THBS1 COL6A2 CO...
[16]:
text_embeddings = loki.utils.encode_text_df(model, tokenizer, bulk_text_feature, 'label', device)
text_embeddings.shape
[16]:
torch.Size([1, 768])
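The cell above averages all samples into a single pseudobulk profile before encoding. If per-sample embeddings are wanted instead, the same pipeline can be applied to the full matrix; the following is a hypothetical variation that assumes the GCT layout shown above (first three columns id / Name / Description, remaining columns samples).

# Hypothetical variation: one embedding per bulk sample instead of the mean profile.
per_sample_df = gct_data.iloc[:, 3:].T
per_sample_df.columns = gct_data['Description'].values
per_sample_ad = anndata.AnnData(per_sample_df)
per_sample_text = loki.preprocess.generate_gene_df(per_sample_ad, house_keeping_genes, todense=False)
per_sample_emb = loki.utils.encode_text_df(model, tokenizer, per_sample_text, 'label', device)
# per_sample_emb would have shape (n_samples, 768)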

Encode patches from a whole image

Example of encoding patches from a whole image. It takes about 28 minutes to run this section on CPU on a MacBook Pro.

[17]:
coord = pd.read_csv(os.path.join(data_dir, 'demo_data', 'coord.csv'), index_col=0)
coord
[17]:
pixel_x pixel_y
GSM5924033_AAACAACGAATAGTTC-1 1579.906543 152.803738
GSM5924033_AAACAAGTATCTCCCA-1 478.504673 1266.355141
GSM5924033_AAACAATCTACTAGCA-1 1234.112151 219.626168
GSM5924033_AAACACCAATAACTGC-1 1541.121497 1467.289721
GSM5924033_AAACAGAGCGACTCCT-1 581.308412 464.485982
... ... ...
GSM5924033_TTGTTTCACATCCAGG-1 1246.728973 1444.859814
GSM5924033_TTGTTTCATTAGTCTA-1 1400.000001 1489.252338
GSM5924033_TTGTTTCCATACAACT-1 1438.785048 1155.140188
GSM5924033_TTGTTTGTATTACACG-1 1259.345795 1778.971964
GSM5924033_TTGTTTGTGTAAATTC-1 1131.775702 308.878505

4975 rows × 2 columns

[18]:
img = Image.open(os.path.join(data_dir, 'demo_data', 'whole_img.png'))
img
[18]:
[Image output: the whole image whole_img.png]
[19]:
img_array = np.asarray(img)
patch_dir = os.path.join(data_dir, 'demo_data', 'patch')
loki.preprocess.segment_patches(img_array, coord, patch_dir)
[20]:
img_list = os.listdir(patch_dir)
patch_paths = [os.path.join(patch_dir, fn) for fn in img_list]
[21]:
image_embeddings = loki.utils.encode_images(model, preprocess, patch_paths, device)
image_embeddings.shape
[21]:
torch.Size([4975, 768])
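One practical note: os.listdir returns file names in arbitrary order, so the rows of image_embeddings are not guaranteed to follow the order of coord. If the patch embeddings need to be matched back to spot barcodes, sorting the file names (or mapping them to barcodes explicitly) keeps the order deterministic, e.g.:

# Deterministic patch order, so embedding rows can be mapped back to barcodes.
img_list = sorted(os.listdir(patch_dir))
patch_paths = [os.path.join(patch_dir, fn) for fn in img_list]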