import os
import os.path as osp
import pickle
from abc import ABC
import anndata as ad
import mudata as md
import numpy as np
import scanpy as sc
from dance import logger
from dance.data import Data
from dance.datasets.base import BaseDataset
from dance.transforms.preprocess import lsiTransformer
from dance.typing import List
from dance.utils.download import download_file, unzip_file
class MultiModalityDataset(BaseDataset, ABC):
    """Shared download and loading logic for the multi-modality datasets.

    Subclasses are expected to override the class attributes below:

    * ``TASK``: one of ``"predict_modality"``, ``"match_modality"`` or ``"joint_embedding"``.
    * ``URL_DICT``: maps a full subtask name to the URL of its zip archive.
    * ``SUBTASK_NAME_MAP``: maps a short alias to a full subtask name.
    * ``AVAILABLE_DATA``: every subtask name (alias or full) accepted by the constructor.

    """

    TASK = "N/A"
    URL_DICT = {}
    SUBTASK_NAME_MAP = {}
    AVAILABLE_DATA = []

    def __init__(self, subtask, root="./data"):
        """Resolve the subtask alias, look up its URL and initialize the base dataset.

        Parameters
        ----------
        subtask
            Subtask name; must be listed in ``AVAILABLE_DATA``.
        root
            Directory under which archives are downloaded and extracted.

        """
        assert subtask in self.AVAILABLE_DATA, f"Undefined subtask {subtask!r}."
        assert self.TASK in ("predict_modality", "match_modality", "joint_embedding")

        # Aliases are translated to the full name; full names pass through unchanged.
        resolved = self.SUBTASK_NAME_MAP.get(subtask, subtask)
        self.subtask = resolved
        self.data_url = self.URL_DICT[resolved]

        super().__init__(root=root, full_download=False)

    def download(self):
        """Download the data of the selected subtask (delegates to :meth:`download_data`)."""
        self.download_data()

    def download_data(self):
        """Fetch ``<root>/<subtask>.zip`` from ``data_url`` and unzip it into ``root``."""
        archive = osp.join(self.root, f"{self.subtask}.zip")
        download_file(self.data_url, archive)
        unzip_file(archive, self.root)

    def download_pathway(self):
        """Download the two ``h.all.v7.4`` ``.gmt`` files (entrez and symbols) into ``root``."""
        gmt_sources = (
            ("https://www.dropbox.com/s/uqoakpalr3albiq/h.all.v7.4.entrez.gmt?dl=1", "h.all.v7.4.entrez.gmt"),
            ("https://www.dropbox.com/s/yjrcsd2rpmahmfo/h.all.v7.4.symbols.gmt?dl=1", "h.all.v7.4.symbols.gmt"),
        )
        for url, filename in gmt_sources:
            download_file(url, osp.join(self.root, filename))

    @property
    def data_paths(self) -> List[str]:
        """Paths of the ``.h5ad`` files expected for the current task and subtask.

        The order of the returned list is significant: the loaders unpack it positionally.

        """
        subtask = self.subtask

        def locate(filename):
            # Every file lives in <root>/<subtask>/.
            return osp.join(self.root, subtask, filename)

        def outputs(dataset_tag, parts):
            # Files named "<subtask>.<dataset_tag>.output_<part>.h5ad".
            return [locate(f"{subtask}.{dataset_tag}.output_{part}.h5ad") for part in parts]

        if self.TASK == "joint_embedding":
            is_cite = "cite" in subtask
            mod = "adt" if is_cite else "atac"
            meta = "cite" if is_cite else "multiome"
            paths = [
                *outputs("censor_dataset", ("mod1", "mod2")),
                locate(f"{meta}_gex_processed_training.h5ad"),
                locate(f"{meta}_{mod}_processed_training.h5ad"),
                *outputs("censor_dataset", ("solution", )),
            ]
        elif self.TASK == "predict_modality":
            dataset_tag = "censor_dataset"
            # (condition, tag) pairs; a later matching rule overrides an earlier one.
            tag_rules = (
                (subtask == "10k_pbmc", "10kanti_dataset_subset"),
                (subtask == "pbmc_cite", "citeanti_dataset"),
                (subtask.startswith("5k_pbmc"), "5kanti_dataset"),
                (subtask.startswith("openproblems_2022"), "open_dataset"),
                (subtask.startswith("GSE127064"), "GSE126074_dataset"),
                (subtask.startswith("GSE117089"), "GSE117089_dataset"),
                (subtask.startswith("GSE140203"), "GSE140203_dataset"),
            )
            for matched, tag in tag_rules:
                if matched:
                    dataset_tag = tag
            paths = outputs(dataset_tag, ("train_mod1", "train_mod2", "test_mod1", "test_mod2"))
        elif self.TASK == "match_modality":
            paths = outputs(
                "censor_dataset",
                ("train_mod1", "train_mod2", "train_sol", "test_mod1", "test_mod2", "test_sol"),
            )
        return paths

    def is_complete(self) -> bool:
        """Return ``True`` when every path in :attr:`data_paths` exists on disk."""
        return all(osp.exists(path) for path in self.data_paths)

    def _load_raw_data(self) -> List[ad.AnnData]:
        """Read every file in :attr:`data_paths` with ``anndata.read_h5ad``, in order."""
        loaded = []
        for h5ad_path in self.data_paths:
            logger.info(f"Loading {h5ad_path}")
            loaded.append(ad.read_h5ad(h5ad_path))
        return loaded
class ModalityPredictionDataset(MultiModalityDataset):
    """Dataset for the ``predict_modality`` task.

    The raw data consists of four AnnData objects, in the order train mod1, train mod2,
    test mod1, test mod2. They are combined into one ``Data`` object with ``mod1`` as the
    feature modality and ``mod2`` as the label modality.

    """

    TASK = "predict_modality"
    URL_DICT = {
        "openproblems_bmmc_cite_phase2_mod2":
        "https://www.dropbox.com/s/snh8knscnlcq4um/openproblems_bmmc_cite_phase2_mod2.zip?dl=1",
        "openproblems_bmmc_cite_phase2_rna":
        "https://www.dropbox.com/s/xbfyhv830u9pupv/openproblems_bmmc_cite_phase2_rna.zip?dl=1",
        "openproblems_bmmc_multiome_phase2_mod2":
        "https://www.dropbox.com/s/p9ve2ljyy4yqna4/openproblems_bmmc_multiome_phase2_mod2.zip?dl=1",
        "openproblems_bmmc_multiome_phase2_rna":
        "https://www.dropbox.com/s/cz60vp7bwapz0kw/openproblems_bmmc_multiome_phase2_rna.zip?dl=1",
        "openproblems_bmmc_cite_phase2_rna_subset":
        "https://www.dropbox.com/s/veytldxkgzyoa8j/openproblems_bmmc_cite_phase2_rna_subset.zip?dl=1",
        "5k_pbmc":
        "https://www.dropbox.com/scl/fi/uoyis946glh0oo7g833qj/5k_pbmc.zip?rlkey=mw9cvqq7e12iowfbr9rp7av5u&dl=1",
        "5k_pbmc_subset":
        "https://www.dropbox.com/scl/fi/pykqc9zyt1fjypnjf4m1l/5k_pbmc_subset.zip?rlkey=brkmnqhfz5yl9axiuu0f8gmxy&dl=1",
        "10k_pbmc":
        "https://www.dropbox.com/scl/fi/npz3n36d3w089creppph2/10k_pbmc.zip?rlkey=6yyv61omv2rw7sqqmfp6u7m1s&dl=1",
        "pbmc_cite":
        "https://www.dropbox.com/scl/fi/8yvel9lu2f4pbemjeihzq/pbmc_cite.zip?rlkey=5f5jpjy1fcg14hwzot0hot7xd&dl=1",
        "openproblems_2022_multi_atac2gex":
        "https://www.dropbox.com/scl/fi/4ynxepu306g3k6vqpi3aw/openproblems_2022_multi_atac2gex.zip?rlkey=2mq5vjnsh26gg5zgq9d85ikcp&dl=1",
        "openproblems_2022_cite_gex2adt":
        "https://www.dropbox.com/scl/fi/dalt3qxwe440107ihjbpy/openproblems_2022_cite_gex2adt.zip?rlkey=ps1fvcr622vhibc1wc1umfdaw&dl=1",
        "GSE127064_AdBrain_gex2atac":
        "https://www.dropbox.com/scl/fi/4ybsx6pgiuy6j9m0y92ly/GSE127064_AdBrain_gex2atac.zip?rlkey=6a5u7p7xr2dqsoduflzxjluja&dl=1",
        "GSE127064_p0Brain_gex2atac":
        "https://www.dropbox.com/scl/fi/k4p3nkkqq56ev6ljyo5se/GSE127064_p0Brain_gex2atac.zip?rlkey=y7kayqmk2l72jjogzlvfxtl74&dl=1",
        "GSE117089_mouse_gex2atac":
        "https://www.dropbox.com/scl/fi/egktuwiognr06xebeuouk/GSE117089_mouse_gex2atac.zip?rlkey=jadp3hlopc3112lmxe6nz5cd1&dl=1",
        "GSE117089_A549_gex2atac":
        "https://www.dropbox.com/scl/fi/b7evc2n5ih5o3xxwcd7uq/GSE117089_A549_gex2atac.zip?rlkey=b5o0ykptfodim59qwnu2m89fh&dl=1",
        "GSE117089_sciCAR_gex2atac":
        "https://www.dropbox.com/scl/fi/juibpvmtv2otvfsq1xyr7/GSE117089_sciCAR_gex2atac.zip?rlkey=qcdbfqsuhab56bc553cwm78gc&dl=1",
        "GSE140203_3T3_HG19_atac2gex":
        "https://www.dropbox.com/scl/fi/v1vbypz87t1rz012vojkh/GSE140203_3T3_HG19_atac2gex.zip?rlkey=xmxrwso5e5ty3w53ctbm5bo9z&dl=1",
        "GSE140203_3T3_MM10_atac2gex":
        "https://www.dropbox.com/scl/fi/po9k064twny51subze6df/GSE140203_3T3_MM10_atac2gex.zip?rlkey=q0b4y58bsvacnjrmvsclk4jqu&dl=1",
        "GSE140203_12878.rep2_atac2gex":
        "https://www.dropbox.com/scl/fi/jqijimb7h6cv4w4hkax1q/GSE140203_12878.rep2_atac2gex.zip?rlkey=c837xkoacap4wjszffpfrmuak&dl=1",
        "GSE140203_12878.rep3_atac2gex":
        "https://www.dropbox.com/scl/fi/wlv9dhvylz78kq8ezncmd/GSE140203_12878.rep3_atac2gex.zip?rlkey=5r607plnqzlxdgxtc4le8d6o1&dl=1",
        "GSE140203_K562_HG19_atac2gex":
        "https://www.dropbox.com/scl/fi/n2he1br3u604p3mgniowz/GSE140203_K562_HG19_atac2gex.zip?rlkey=2lhe7s5run8ly5uk4b0vfemyj&dl=1",
        "GSE140203_K562_MM10_atac2gex":
        "https://www.dropbox.com/scl/fi/dhdorqy87915uah3xl07a/GSE140203_K562_MM10_atac2gex.zip?rlkey=ecwsy5sp7f2i2gtjo1qyaf4zt&dl=1",
        "GSE140203_LUNG_atac2gex":
        "https://www.dropbox.com/scl/fi/gabugiw244ky85j3ckq4d/GSE140203_LUNG_atac2gex.zip?rlkey=uj0we276s6ay2acpioj4tmfj3&dl=1"
    }
    SUBTASK_NAME_MAP = {
        "adt2gex": "openproblems_bmmc_cite_phase2_mod2",
        "atac2gex": "openproblems_bmmc_multiome_phase2_mod2",
        "gex2adt": "openproblems_bmmc_cite_phase2_rna",
        "gex2atac": "openproblems_bmmc_multiome_phase2_rna",
        "gex2adt_subset": "openproblems_bmmc_cite_phase2_rna_subset",
    }
    AVAILABLE_DATA = sorted(list(URL_DICT) + list(SUBTASK_NAME_MAP))

    def __init__(self, subtask, root="./data", preprocess=None, span=0.3):
        """Store the preprocessing options, then run the base initialization.

        Parameters
        ----------
        subtask
            Subtask name or alias; must be in ``AVAILABLE_DATA``.
        root
            Data root directory.
        preprocess
            ``"feature_selection"`` to enable highly-variable-gene selection; ``None`` or
            ``"none"`` to skip preprocessing.
        span
            Forwarded as ``span`` to ``scanpy.pp.highly_variable_genes``.

        """
        # TODO: factor our preprocess
        self.preprocess = preprocess
        self.span = span
        super().__init__(subtask, root)

    def _raw_to_dance(self, raw_data):
        """Stack train/test of each modality and wrap them in a ``Data`` object.

        The train rows come first, so ``train_size`` is the number of train cells in mod1.

        """
        train_in, train_out, test_in, test_out = self._maybe_preprocess(raw_data)

        stacked = {}
        for mod_key, parts in (("mod1", (train_in, test_in)), ("mod2", (train_out, test_out))):
            merged = ad.concat(parts)
            merged.var_names_make_unique()
            stacked[mod_key] = merged

        mdata = md.MuData(stacked)
        mdata.var_names_make_unique()

        data = Data(mdata, train_size=train_in.shape[0])
        data.set_config(feature_mod="mod1", label_mod="mod2")
        return data

    def _maybe_preprocess(self, raw_data, selection_threshold=10000):
        """Optionally restrict mod1 (train and test) to highly variable features.

        Selection only happens when ``preprocess == "feature_selection"`` and train mod1 has
        more than ``selection_threshold`` features. ``raw_data`` is modified in place and
        returned.

        """
        method = self.preprocess
        if method == "feature_selection":
            train_mod1 = raw_data[0]
            if train_mod1.shape[1] > selection_threshold:
                sc.pp.highly_variable_genes(train_mod1, layer="counts", flavor="seurat_v3",
                                            n_top_genes=selection_threshold, span=self.span)
                # Reuse the mask computed on the training split for the test split.
                raw_data[2].var["highly_variable"] = train_mod1.var["highly_variable"]
                for split_idx in (0, 2):
                    adata = raw_data[split_idx]
                    raw_data[split_idx] = adata[:, adata.var["highly_variable"]]
        elif method not in (None, "none"):
            logger.info(f"Preprocessing method {method!r} not supported.")
        logger.info("Preprocessing done.")
        return raw_data
class ModalityMatchingDataset(MultiModalityDataset):
    """Dataset for the ``match_modality`` task.

    The raw data consists of six AnnData objects, in the order train mod1, train mod2,
    train solution, test mod1, test mod2, test solution.

    """

    TASK = "match_modality"
    URL_DICT = {
        "openproblems_bmmc_cite_phase2_mod2":
        "https://www.dropbox.com/s/fa6zut89xx73itz/openproblems_bmmc_cite_phase2_mod2.zip?dl=1",
        "openproblems_bmmc_cite_phase2_rna":
        "https://www.dropbox.com/s/ep00mqcjmdu0b7v/openproblems_bmmc_cite_phase2_rna.zip?dl=1",
        "openproblems_bmmc_multiome_phase2_mod2":
        "https://www.dropbox.com/s/31qi5sckx768acw/openproblems_bmmc_multiome_phase2_mod2.zip?dl=1",
        "openproblems_bmmc_multiome_phase2_rna":
        "https://www.dropbox.com/s/h1s067wkefs1jh2/openproblems_bmmc_multiome_phase2_rna.zip?dl=1",
        "openproblems_bmmc_cite_phase2_rna_subset":
        "https://www.dropbox.com/s/3q4xwpzjbe81x58/openproblems_bmmc_cite_phase2_rna_subset.zip?dl=1",
    }
    SUBTASK_NAME_MAP = {
        "adt2gex": "openproblems_bmmc_cite_phase2_mod2",
        "atac2gex": "openproblems_bmmc_multiome_phase2_mod2",
        "gex2adt": "openproblems_bmmc_cite_phase2_rna",
        "gex2atac": "openproblems_bmmc_multiome_phase2_rna",
        "gex2adt_subset": "openproblems_bmmc_cite_phase2_rna_subset",
    }
    AVAILABLE_DATA = sorted(list(URL_DICT) + list(SUBTASK_NAME_MAP))

    def __init__(self, subtask, root="./data", preprocess=None, pkl_path=None, span=0.3):
        """Store the preprocessing options, then run the base initialization.

        Parameters
        ----------
        subtask
            Subtask name or alias; must be in ``AVAILABLE_DATA``.
        root
            Data root directory.
        preprocess
            ``None`` (no preprocessing), ``"pca"`` or ``"feature_selection"``.
        pkl_path
            Optional pickle file used to cache the ``"pca"`` features.
        span
            Forwarded as ``span`` to ``scanpy.pp.highly_variable_genes``.

        """
        # TODO: factor our preprocess
        self.preprocess = preprocess
        self.pkl_path = pkl_path
        self.span = span
        super().__init__(subtask, root)

    def _raw_to_dance(self, raw_data):
        """Align the training pairs, stack train/test and wrap them in a ``Data`` object."""
        train_a, train_b, train_sol, test_a, test_b, test_sol = self._maybe_preprocess(raw_data)

        # Align matched cells: reorder train mod2 by the argmax of each solution row.
        matched_rows = train_sol.to_df().values.argmax(1)
        train_b = train_b[matched_rows]

        mod1 = ad.concat((train_a, test_a))
        mod2 = ad.concat((train_b, test_b))
        for merged in (mod1, mod2):
            merged.var_names_make_unique()
        mod2.obs_names = mod1.obs_names

        train_size = train_a.shape[0]
        # Train rows get label 0; test rows get the argmax of the test solution.
        test_targets = np.argmax(test_sol.X.toarray(), 1)
        mod1.obsm["labels"] = np.concatenate([np.zeros(train_size), test_targets])

        # Combine modalities into mudata
        mdata = md.MuData({"mod1": mod1, "mod2": mod2})
        mdata.var_names_make_unique()
        return Data(mdata, train_size=train_size)

    def _maybe_preprocess(self, raw_data, selection_threshold=10000):
        """Apply the configured preprocessing to the four modality matrices.

        Returns ``raw_data`` untouched when ``preprocess`` is ``None``; otherwise a 6-tuple in
        the same order as ``raw_data``. The solution matrices are never modified.

        """
        if self.preprocess is None:
            return raw_data

        train_mod1, train_mod2, train_label, test_mod1, test_mod2, test_label = raw_data
        modalities = [train_mod1, train_mod2, test_mod1, test_mod2]

        cite_subtasks = ("openproblems_bmmc_cite_phase2_rna", "openproblems_bmmc_cite_phase2_rna_subset")
        multiome_subtask = "openproblems_bmmc_multiome_phase2_rna"

        # TODO: support other two subtasks
        assert self.subtask in cite_subtasks + (multiome_subtask, ), "Currently not available."

        if self.preprocess == "pca":
            if self.pkl_path and osp.exists(self.pkl_path):
                # Reuse previously computed features.
                with open(self.pkl_path, "rb") as f:
                    preprocessed_features = pickle.load(f)
            else:
                keep_raw_mod2 = self.subtask in cite_subtasks
                if not keep_raw_mod2 and self.subtask != multiome_subtask:
                    raise ValueError(f"Unrecognized subtask name: {self.subtask}")

                # mod1 is reduced with a 256-component LSI fitted on the training split.
                gex_lsi = lsiTransformer(n_components=256, drop_first=True)
                m1_train = gex_lsi.fit_transform(modalities[0]).values
                m1_test = gex_lsi.transform(modalities[2]).values

                if keep_raw_mod2:
                    # mod2 is used as a dense matrix without reduction.
                    m2_train = modalities[1].X.toarray()
                    m2_test = modalities[3].X.toarray()
                else:
                    atac_lsi = lsiTransformer(n_components=512, drop_first=True)
                    m2_train = atac_lsi.fit_transform(modalities[1]).values
                    m2_test = atac_lsi.transform(modalities[3]).values

                preprocessed_features = {
                    "mod1_train": m1_train,
                    "mod2_train": m2_train,
                    "mod1_test": m1_test,
                    "mod2_test": m2_test
                }
                if self.pkl_path:
                    with open(self.pkl_path, "wb") as f:
                        pickle.dump(preprocessed_features, f)

            # Keys listed in the same order as ``modalities``.
            feature_keys = ("mod1_train", "mod2_train", "mod1_test", "mod2_test")
            for adata, key in zip(modalities, feature_keys):
                adata.obsm["X_pca"] = preprocessed_features[key]

        elif self.preprocess == "feature_selection":
            for train_idx, test_idx in ((0, 2), (1, 3)):
                if modalities[train_idx].shape[1] > selection_threshold:
                    sc.pp.highly_variable_genes(modalities[train_idx], layer="counts", flavor="seurat_v3",
                                                n_top_genes=selection_threshold, span=self.span)
                    # The mask computed on the training split is applied to the test split too.
                    hvg_mask = modalities[train_idx].var["highly_variable"]
                    modalities[test_idx].var["highly_variable"] = hvg_mask
                    modalities[train_idx] = modalities[train_idx][:, modalities[train_idx].var["highly_variable"]]
                    modalities[test_idx] = modalities[test_idx][:, modalities[test_idx].var["highly_variable"]]
        else:
            logger.info("Preprocessing method not supported.")

        logger.info("Preprocessing done.")
        train_mod1, train_mod2, test_mod1, test_mod2 = modalities
        return train_mod1, train_mod2, train_label, test_mod1, test_mod2, test_label
class JointEmbeddingNIPSDataset(MultiModalityDataset):
    """Dataset for the ``joint_embedding`` task.

    The raw data consists of five AnnData objects, in the order mod1, mod2, meta1, meta2,
    test solution. After "aux" preprocessing the instance exposes ``nb_cell_types``,
    ``nb_batches`` and ``nb_phases``.

    """

    TASK = "joint_embedding"
    URL_DICT = {
        "openproblems_bmmc_cite_phase2":
        "https://www.dropbox.com/s/hjr4dxuw55vin5z/openproblems_bmmc_cite_phase2.zip?dl=1",
        "openproblems_bmmc_multiome_phase2":
        "https://www.dropbox.com/s/40kjslupxhkg92s/openproblems_bmmc_multiome_phase2.zip?dl=1"
    }
    SUBTASK_NAME_MAP = {
        "adt": "openproblems_bmmc_cite_phase2",
        "atac": "openproblems_bmmc_multiome_phase2",
    }
    AVAILABLE_DATA = sorted(list(URL_DICT) + list(SUBTASK_NAME_MAP))

    def __init__(self, subtask, root="./data", preprocess=None, normalize=False, pretrained_folder="."):
        """Store the preprocessing options, then run the base initialization.

        Parameters
        ----------
        subtask
            Subtask name or alias; must be in ``AVAILABLE_DATA``.
        root
            Data root directory.
        preprocess
            ``None`` (no preprocessing), ``"aux"`` or ``"feature_selection"``.
        normalize
            If true, ``scanpy.pp.scale`` is applied to mod1 and mod2 at the end of preprocessing.
        pretrained_folder
            Directory holding the pickled LSI transformers and cached "aux" results.

        """
        # TODO: factor our preprocess
        self.preprocess = preprocess
        self.normalize = normalize
        self.pretrained_folder = pretrained_folder
        super().__init__(subtask, root)

    def _raw_to_dance(self, raw_data):
        """Bundle the five (optionally preprocessed) AnnData objects into a ``Data`` object.

        Raises
        ------
        AssertionError
            If the observation names of mod1 and mod2 differ.

        """
        mod1, mod2, meta1, meta2, test_sol = self._maybe_preprocess(raw_data)

        assert all(mod2.obs_names == mod1.obs_names), "Modalities not aligned"
        mdata = md.MuData({"mod1": mod1, "mod2": mod2, "meta1": meta1, "meta2": meta2, "test_sol": test_sol})

        # The number of cells in meta1 is used as the training-set size.
        train_size = meta1.shape[0]
        data = Data(mdata, train_size=train_size)
        return data

    def _load_or_fit_lsi(self, filename, adata, n_components):
        """Return the LSI transformer pickled as ``filename``; fit and pickle it if missing."""
        path = osp.join(self.pretrained_folder, filename)
        if osp.exists(path):
            # NOTE: unpickling executes arbitrary code; only use a trusted ``pretrained_folder``.
            with open(path, "rb") as f:
                return pickle.load(f)

        transformer = lsiTransformer(n_components=n_components, drop_first=True)
        transformer.fit(adata)
        with open(path, "wb") as f:
            pickle.dump(transformer, f)
        return transformer

    def _maybe_preprocess(self, raw_data, selection_threshold=10000):
        """Apply the configured preprocessing to mod1 and mod2.

        * ``None``: return ``raw_data`` unchanged.
        * ``"aux"``: store LSI/dense features in ``obsm["X_pca"]`` and auxiliary labels
          (cell type, batch, cell-cycle phase, S and G2M scores) in ``mod1.obsm``. Results
          are cached as pickles under ``pretrained_folder`` and reused on later calls.
        * ``"feature_selection"``: keep the top ``selection_threshold`` highly variable
          features of any modality that has more features than that.

        Returns
        -------
        ``raw_data`` itself when no preprocessing is requested, otherwise the tuple
        ``(mod1, mod2, meta1, meta2, test_sol)``.

        Raises
        ------
        ValueError
            In "aux" mode, if the feature type of mod2 is neither ``"ADT"`` nor ``"ATAC"``.

        """
        if self.preprocess is None:
            return raw_data

        mod1, mod2, meta1, meta2, test_sol = raw_data
        train_size = meta1.shape[0]
        aux_keys = ("batch_label", "phase_labels", "S_scores", "G2M_scores")

        # aux -> cell cycle analysis
        if self.preprocess == "aux":
            os.makedirs(self.pretrained_folder, exist_ok=True)
            cache_path = osp.join(self.pretrained_folder, f"preprocessed_data_{self.subtask}.pkl")
            config_path = osp.join(self.pretrained_folder, f"{self.subtask}_config.pk")

            if osp.exists(cache_path):
                # NOTE: unpickling executes arbitrary code; only use a trusted ``pretrained_folder``.
                with open(cache_path, "rb") as f:
                    preprocessed_data = pickle.load(f)

                y_train = preprocessed_data["y_train"]
                mod1.obsm["X_pca"] = preprocessed_data["X_pca_0"]
                mod2.obsm["X_pca"] = preprocessed_data["X_pca_1"]
                mod1.obsm["cell_type"] = y_train[0]
                # Only the first ``train_size`` cells have auxiliary labels; pad the rest with zeros.
                padding = np.zeros(y_train[0].shape[0] - train_size)
                for key, values in zip(aux_keys, y_train[1:]):
                    mod1.obsm[key] = np.concatenate([values, padding], 0)

                with open(config_path, "rb") as f:
                    # cell types, batch labels, cell cycle
                    self.nb_cell_types, self.nb_batches, self.nb_phases = pickle.load(f)

                # NOTE(review): this cached path returns before the ``normalize`` step at the
                # bottom, unlike the first (uncached) run. Confirm this is intended.
                logger.info("Preprocessing done.")
                return mod1, mod2, meta1, meta2, test_sol

            # PCA
            # ``.iloc[0]``: plain ``[0]`` on a label-indexed Series is a deprecated positional lookup.
            mod1_name = mod1.var["feature_types"].iloc[0]
            mod2_name = mod2.var["feature_types"].iloc[0]
            if mod2_name == "ADT":
                lsi_transformer_gex = self._load_or_fit_lsi(f"lsi_cite_{mod1_name}.pkl", mod1, 256)
            elif mod2_name == "ATAC":
                lsi_transformer_gex = self._load_or_fit_lsi(f"lsi_multiome_{mod1_name}.pkl", mod1, 64)
                lsi_transformer_atac = self._load_or_fit_lsi(f"lsi_multiome_{mod2_name}.pkl", mod2, 512)

            # Data preprocessing
            # Only exploration dataset provides cell type information.
            # The exploration dataset is a subset of the full dataset.
            ad_mod1 = meta1
            mod1_obs = ad_mod1.obs

            # Make sure exploration data match the full data
            assert ((mod1.obs["batch"].index[:mod1_obs.shape[0]] == mod1_obs["batch"].index).mean() == 1)

            if mod2_name == "ADT":
                mod1_pca = lsi_transformer_gex.transform(mod1).values
                mod2_pca = mod2.X.toarray()
            elif mod2_name == "ATAC":
                mod1_pca = lsi_transformer_gex.transform(mod1).values
                mod2_pca = lsi_transformer_atac.transform(mod2).values
            else:
                raise ValueError(f"Unknown modality 2: {mod2_name}")

            # Gene lists passed to ``score_genes_cell_cycle`` as ``s_genes`` and ``g2m_genes``.
            s_genes = [
                "MCM5", "PCNA", "TYMS", "FEN1", "MCM2", "MCM4", "RRM1", "UNG", "GINS2", "MCM6", "CDCA7", "DTL",
                "PRIM1", "UHRF1", "MLF1IP", "HELLS", "RFC2", "RPA2", "NASP", "RAD51AP1", "GMNN", "WDR76", "SLBP",
                "CCNE2", "UBR7", "POLD3", "MSH2", "ATAD2", "RAD51", "RRM2", "CDC45", "CDC6", "EXO1", "TIPIN",
                "DSCC1", "BLM", "CASP8AP2", "USP1", "CLSPN", "POLA1", "CHAF1B", "BRIP1", "E2F8"
            ]
            g2m_genes = [
                "HMGB2", "CDK1", "NUSAP1", "UBE2C", "BIRC5", "TPX2", "TOP2A", "NDC80", "CKS2", "NUF2", "CKS1B",
                "MKI67", "TMPO", "CENPF", "TACC3", "FAM64A", "SMC4", "CCNB2", "CKAP2L", "CKAP2", "AURKB", "BUB1",
                "KIF11", "ANP32E", "TUBB4B", "GTSE1", "KIF20B", "HJURP", "CDCA3", "HN1", "CDC20", "TTK", "CDC25C",
                "KIF2C", "RANGAP1", "NCAPD2", "DLGAP5", "CDCA2", "CDCA8", "ECT2", "KIF23", "HMMR", "AURKA", "PSRC1",
                "ANLN", "LBR", "CKAP5", "CENPE", "CTCF", "NEK2", "G2E3", "GAS2L3", "CBX5", "CENPA"
            ]

            logger.info(f"Data loading and pca done: {mod1_pca.shape=}, {mod2_pca.shape=}")
            logger.info("Start to calculate cell_cycle score. It may roughly take an hour.")

            cell_type_labels = test_sol.obs["cell_type"].to_numpy()
            batch_ids = mod1_obs["batch"]
            phase_labels = mod1_obs["phase"]

            # Encode each label as its index among the sorted unique values.
            cell_types_unique, c_labels = np.unique(cell_type_labels, return_inverse=True)
            batches_unique, b_labels = np.unique(batch_ids, return_inverse=True)
            phases_unique, p_labels = np.unique(phase_labels, return_inverse=True)

            nb_cell_types = len(cell_types_unique)
            nb_batches = len(batches_unique)
            nb_phases = len(phases_unique) - 1  # 2
            # 0:G1, 1:G2M, 2: S, only consider the last two

            # ``ad_mod1`` is ``meta1`` itself, so these calls modify the returned ``meta1`` in place.
            sc.pp.log1p(ad_mod1)
            sc.pp.scale(ad_mod1)
            sc.tl.score_genes_cell_cycle(ad_mod1, s_genes=s_genes, g2m_genes=g2m_genes)
            S_scores = ad_mod1.obs["S_score"].values
            G2M_scores = ad_mod1.obs["G2M_score"].values

            y_train = [c_labels, b_labels, p_labels, S_scores, G2M_scores]

            mod1.obsm["X_pca"] = mod1_pca
            mod2.obsm["X_pca"] = mod2_pca
            train_size = mod1_obs.shape[0]
            mod1.obsm["cell_type"] = c_labels
            # Only the first ``train_size`` cells have auxiliary labels; pad the rest with zeros.
            padding = np.zeros(mod1.shape[0] - train_size)
            for key, values in zip(aux_keys, y_train[1:]):
                mod1.obsm[key] = np.concatenate([values, padding], 0)

            # Cache the results so that later runs take the early-return path above.
            preprocessed_data = {"X_pca_0": mod1.obsm["X_pca"], "X_pca_1": mod2.obsm["X_pca"], "y_train": y_train}
            with open(cache_path, "wb") as f:
                pickle.dump(preprocessed_data, f)
            with open(config_path, "wb") as f:
                pickle.dump([nb_cell_types, nb_batches, nb_phases], f)

            self.nb_cell_types, self.nb_batches, self.nb_phases = nb_cell_types, nb_batches, nb_phases

        elif self.preprocess == "feature_selection":
            if mod1.shape[1] > selection_threshold:
                sc.pp.highly_variable_genes(mod1, layer="counts", flavor="seurat_v3", n_top_genes=selection_threshold)
                mod1 = mod1[:, mod1.var["highly_variable"]]
            if mod2.shape[1] > selection_threshold:
                sc.pp.highly_variable_genes(mod2, layer="counts", flavor="seurat_v3", n_top_genes=selection_threshold)
                mod2 = mod2[:, mod2.var["highly_variable"]]
        else:
            logger.info(f"Preprocessing method {self.preprocess!r} not supported.")

        # Normalization
        if self.normalize:
            sc.pp.scale(mod1)
            sc.pp.scale(mod2)

        logger.info("Preprocessing done.")
        return mod1, mod2, meta1, meta2, test_sol