Source code for dance.transforms.pseudo_gen

from functools import partial

import anndata as ad
import numpy as np
import pandas as pd

from dance import logger as native_logger
from dance.data import Data
from dance.transforms.base import BaseTransform
from dance.typing import Callable, Dict, List, Literal, Logger, Optional, Tuple, Union


class PseudoMixture(BaseTransform):
    """Build pseudo-mixture samples by summing randomly drawn cells of a split.

    Each pseudo sample is the sum of the feature rows of between ``nc_min`` and ``nc_max`` (both inclusive) cells
    drawn without replacement from the input split. The generated samples are appended to the data object as a new
    split, together with per-sample cell-type portions stored under ``obsm["cell_type_portion"]``.

    Parameters
    ----------
    n_pseudo
        Number of pseudo-mixture samples to generate.
    nc_min
        Smallest number of cells drawn for one mixture.
    nc_max
        Largest number of cells drawn for one mixture.
    ct_select
        Cell types used as columns of the count/portion table, or ``"auto"`` to use every cell type found in the
        annotation.
    ct_key
        Name of the ``obs`` channel that holds the cell-type annotation.
    channel
        Channel forwarded to ``data.get_feature`` when reading the features.
    channel_type
        Channel type forwarded to ``data.get_feature`` when reading the features.
    random_state
        Seed handed to ``np.random.default_rng``.
    prefix
        Prefix of the generated sample names (``f"{prefix}{i}"``).
    in_split_name
        Split the cells are drawn from.
    out_split_name
        Name of the split that receives the generated samples.
    label_batch
        Forwarded unchanged to ``data.append``.

    """

    _DISPLAY_ATTRS = ("n_pseudo", "nc_min", "nc_max", "ct_select")

    def __init__(
        self,
        *,
        n_pseudo: int = 1000,
        nc_min: int = 2,
        nc_max: int = 10,
        ct_select: Union[Literal["auto"], List[str]] = "auto",
        ct_key: str = "cellType",
        channel: Optional[str] = None,
        channel_type: Optional[str] = "X",
        random_state: Optional[int] = 0,
        prefix: str = "ps_mix_",
        in_split_name: str = "ref",
        out_split_name: Optional[str] = "pseudo",
        label_batch: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Mixture size and count
        self.n_pseudo = n_pseudo
        self.nc_min = nc_min
        self.nc_max = nc_max

        # Cell-type bookkeeping
        self.ct_select = ct_select
        self.ct_key = ct_key

        # Where the features are read from
        self.channel = channel
        self.channel_type = channel_type
        self.in_split_name = in_split_name

        # Sampling and output options
        self.random_state = random_state
        self.prefix = prefix
        self.out_split_name = out_split_name
        self.label_batch = label_batch

    @staticmethod
    def gen_mix(
        x: np.ndarray,
        annot: np.ndarray,
        nc_min: int = 2,
        nc_max: int = 10,
        clust_vr: str = "cellType",
        rng: Optional[np.random.Generator] = None,
    ) -> Tuple[np.ndarray, Dict[str, int], Dict[str, float]]:
        """Draw one pseudo mixture from the given cells.

        Parameters
        ----------
        x
            Feature array indexed by cell along the first axis.
        annot
            Cell-type annotation, indexed like the first axis of ``x``.
        nc_min
            Smallest number of cells to draw.
        nc_max
            Largest number of cells to draw.
        clust_vr
            Not used by this function.
        rng
            Random number generator; a fresh ``np.random.default_rng()`` is created when not given.

        Returns
        -------
        Tuple[np.ndarray, Dict[str, int], Dict[str, float]]
            The summed features of the drawn cells, a mapping from cell type to the number of drawn cells of that
            type, and a dictionary with the keys ``"cell_count"`` and ``"total_umi_count"``.

        """
        if not rng:
            rng = np.random.default_rng()

        # Upper bound of ``integers`` is exclusive, hence the +1 to make nc_max reachable
        n_cells = rng.integers(nc_min, nc_max + 1)
        picked = rng.choice(x.shape[0], size=n_cells, replace=False)

        summed = x[picked].sum(0)
        cell_types, counts = np.unique(annot[picked], return_counts=True)
        ct_counts = dict(zip(cell_types, counts))
        stats = {"cell_count": n_cells, "total_umi_count": summed.sum()}

        return summed, ct_counts, stats

    def __call__(self, data):
        """Generate the pseudo mixtures and append them to ``data`` as a new split."""
        x = data.get_feature(
            split_name=self.in_split_name,
            channel=self.channel,
            channel_type=self.channel_type,
            return_type="numpy",
        )
        annot = data.get_feature(
            split_name=self.in_split_name,
            channel=self.ct_key,
            channel_type="obs",
            return_type="numpy",
        )

        # One generator shared by all draws so the whole run is reproducible from random_state
        rng = np.random.default_rng(self.random_state)

        mix_x = np.zeros((self.n_pseudo, x.shape[1]), dtype=np.float32)
        ct_count_records = []
        info_records = []
        for row in range(self.n_pseudo):
            counts, ct_counts, info = self.gen_mix(
                x,
                annot,
                nc_min=self.nc_min,
                nc_max=self.nc_max,
                clust_vr=self.ct_key,
                rng=rng,
            )
            mix_x[row] = counts
            ct_count_records.append(ct_counts)
            info_records.append(info)

        ct_select = get_cell_types(self.ct_select, annot)
        sample_names = [f"{self.prefix}{row}" for row in range(self.n_pseudo)]

        # Cell types absent from a mixture come out as missing values; treat them as zero counts
        ct_counts_df = pd.DataFrame(ct_count_records, columns=ct_select, index=sample_names)
        ct_counts_df = ct_counts_df.fillna(0)
        row_totals = ct_counts_df.sum(axis=1)
        ct_portion_df = ct_counts_df.div(row_totals, axis=0)  # normalize to portions

        obs = pd.DataFrame(info_records, index=sample_names)
        pseudo_adata = ad.AnnData(mix_x, obs=obs, var=data.data.var, obsm={"cell_type_portion": ct_portion_df})
        pseudo_data = Data(pseudo_adata)
        data.append(
            pseudo_data,
            join="outer",
            mode="new_split",
            new_split_name=self.out_split_name,
            label_batch=self.label_batch,
        )
class CellTopicProfile(BaseTransform):
    """Compute per-cell-type aggregated feature profiles and store them in ``varm``.

    The profile is computed by ``get_ct_profile`` and written to ``data.data.varm[self.out]`` as a data frame
    indexed by ``data.data.var_names`` with one column per selected cell type.

    Parameters
    ----------
    ct_select
        Cell types to build profiles for, or ``"auto"`` to use every cell type found in the annotation.
    ct_key
        Name of the ``obs`` channel that holds the cell-type annotation.
    batch_key
        Name of the ``obs`` channel that holds batch labels. When ``None``, no batch information is passed on.
    split_name
        Split the features and annotations are read from.
    channel
        Channel forwarded to ``data.get_feature`` when reading the features.
    channel_type
        Channel type forwarded to ``data.get_feature`` when reading the features.
    method
        Aggregation method forwarded to ``get_ct_profile``.

    """

    _DISPLAY_ATTRS = ("ct_select", "ct_key", "split_name", "method")

    def __init__(
        self,
        *,
        ct_select: Union[Literal["auto"], List[str]] = "auto",
        ct_key: str = "cellType",
        batch_key: Optional[str] = None,
        split_name: Optional[str] = None,
        channel: Optional[str] = None,
        channel_type: str = "X",
        method: Literal["median", "mean"] = "median",
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Annotation sources
        self.ct_select = ct_select
        self.ct_key = ct_key
        self.batch_key = batch_key

        # Feature source
        self.split_name = split_name
        self.channel = channel
        self.channel_type = channel_type

        self.method = method

    def __call__(self, data):
        """Compute the cell-type profiles of ``data`` and save them under ``varm``."""

        def fetch(channel, channel_type):
            # All reads share the same split and return type
            return data.get_feature(
                split_name=self.split_name,
                channel=channel,
                channel_type=channel_type,
                return_type="numpy",
            )

        x = fetch(self.channel, self.channel_type)
        annot = fetch(self.ct_key, "obs")
        batch_index = fetch(self.batch_key, "obs") if self.batch_key is not None else None

        ct_select = get_cell_types(self.ct_select, annot)
        ct_profile = get_ct_profile(
            x,
            annot,
            batch_index=batch_index,
            ct_select=ct_select,
            method=self.method,
            logger=self.logger,
        )

        # NOTE(review): ``self.out`` is not assigned in this class; presumably provided by BaseTransform - confirm
        data.data.varm[self.out] = pd.DataFrame(ct_profile, index=data.data.var_names, columns=ct_select)
def get_cell_types(ct_select: Union[Literal["auto"], List[str]], annot: np.ndarray) -> List[str]:
    """Resolve the cell-type selection against the available annotation.

    Parameters
    ----------
    ct_select
        Either ``"auto"`` (use every cell type present in ``annot``) or an explicit collection of cell-type names.
    annot
        Cell-type annotation array.

    Returns
    -------
    List[str]
        The sorted unique values of ``annot`` when ``ct_select`` is ``"auto"``; otherwise ``ct_select`` itself,
        unchanged.

    Raises
    ------
    ValueError
        If ``ct_select`` contains a cell type that does not appear in ``annot``.

    """
    all_cts = sorted(np.unique(annot))

    # Check the type first: comparing an array-like selection against a string would produce an element-wise
    # result whose truth value is ambiguous.
    if isinstance(ct_select, str) and ct_select == "auto":
        return all_cts

    if len(missed := sorted(set(ct_select) - set(all_cts))) > 0:
        raise ValueError(f"Unknown cell types selected: {missed}. Available options are: {all_cts}")

    return ct_select


def get_agg_func(name: str, *, default: Optional[str] = None) -> Callable[[np.ndarray], np.ndarray]:
    """Return the aggregation function (applied along axis 0) matching the given name.

    Parameters
    ----------
    name
        One of ``"median"``, ``"mean"``, or ``"default"``. The latter is replaced by ``default``.
    default
        Name to use when ``name`` is ``"default"``.

    Returns
    -------
    Callable[[np.ndarray], np.ndarray]
        ``np.median`` or ``np.mean`` with ``axis=0`` bound.

    Raises
    ------
    ValueError
        If ``name`` is ``"default"`` but ``default`` is ``None``, or if the resolved name is not a known
        aggregation method.

    """
    if name == "default":
        if default is None:
            raise ValueError("Aggregation function name set to 'default' but default option not set")
        name = default

    if name == "median":
        return partial(np.median, axis=0)
    if name == "mean":
        return partial(np.mean, axis=0)

    raise ValueError(f"Unknown aggregation method {name!r}. Available options are: 'median', 'mean'")


def get_ct_profile(
    x: np.ndarray,
    annot: np.ndarray,
    *,
    batch_index: Optional[np.ndarray] = None,
    ct_select: Union[Literal["auto"], List[str]] = "auto",
    method: Literal["median", "mean"] = "mean",
    logger: Optional[Logger] = None,
) -> np.ndarray:
    """Aggregate a feature profile for each selected cell type.

    For every cell type, the samples are grouped by batch. Within each batch the features are aggregated and
    normalized by their sum (the "library size" of the aggregate). The normalized batch profiles and the batch
    library sizes are then aggregated across batches separately and multiplied back together.

    Parameters
    ----------
    x
        Sample-by-feature array.
    annot
        Cell-type annotation, one entry per row of ``x``.
    batch_index
        Batch label of each row of ``x``. When ``None``, all samples are treated as a single batch.
    ct_select
        Cell types to build profiles for, or ``"auto"`` to use every cell type found in ``annot``.
    method
        Aggregation method, ``"median"`` or ``"mean"``.
    logger
        Logger used for progress messages; the package-level logger is used when not given.

    Returns
    -------
    np.ndarray
        Array of dtype ``float32`` with shape ``(x.shape[1], len(ct_select))``, i.e., feature by cell type.

    Raises
    ------
    ValueError
        If ``ct_select`` contains unknown cell types or ``method`` is not a known aggregation method.

    """
    if logger is None:
        logger = native_logger
    ct_select = get_cell_types(ct_select, annot)
    agg_func = get_agg_func(method, default="mean")

    if batch_index is None:
        batch_index = np.zeros(x.shape[0], dtype=int)

    # Aggregate profile for each selected cell types
    logger.info(f"Generating cell-type profiles ({method!r} aggregation) for {ct_select}")
    ct_profile = np.zeros((x.shape[1], len(ct_select)), dtype=np.float32)  # gene x cell
    for i, ct in enumerate(ct_select):
        ct_index = np.where(annot == ct)[0]
        logger.info(f"Aggregating {ct!r} profiles over {ct_index.size:,} samples")

        # Get features within a cell type (sliced once, reused for every batch)
        x_ct = x[ct_index]
        sub_batch_index = batch_index[ct_index]
        batches = np.unique(sub_batch_index)

        # Aggregate cell type profile for each batch
        sub_ct_profile = np.zeros((batches.size, x.shape[1]), dtype=np.float32)  # cell x gene
        sub_ct_mean_lib_sizes = np.zeros(batches.size, dtype=np.float32)
        for j, batch_id in enumerate(batches):
            idx = np.where(sub_batch_index == batch_id)[0]
            sub_ct_profile[j] = agg_func(x_ct[idx])
            sub_ct_mean_lib_sizes[j] = sub_ct_profile[j].sum()
            # A zero library size (e.g., all-zero median on sparse counts) would turn the profile into NaN via
            # 0 / 0; leave such a profile as is instead of normalizing it.
            if sub_ct_mean_lib_sizes[j] != 0:
                sub_ct_profile[j] /= sub_ct_mean_lib_sizes[j]
            logger.info(f"Number of {ct!r} cells in batch {batch_id!r}: {idx.size:,}")

        # Aggregate cell type profile over batches
        ct_profile[:, i] = agg_func(sub_ct_profile) * agg_func(sub_ct_mean_lib_sizes)

    logger.info("Cell-type profile generated")

    return ct_profile