Source code for dance.transforms.stats

from pprint import pformat

import pandas as pd

from dance.registers import GENESTATS_FUNCS, register_genestats_func
from dance.transforms.base import BaseTransform
from dance.typing import List, Optional, Union
from dance.utils.wrappers import as_1d_array


class GeneStats(BaseTransform):
    """Gene statistics computation.

    Parameters
    ----------
    genestats_select
        Name, or list of names, of the gene stats functions to use. If set to ``"all"`` (by default), then use all
        available gene stats functions.
    fill_na
        If not set (default), then do not fill nans. Otherwise, fill nans with the specified value.
    threshold
        Threshold value for filtering gene expression when computing stats, e.g., mean expression values.
    pseudo
        If set to ``True``, then add ``1`` to the numerator and denominator when computing the ratio (``alpha``) for
        which the gene expression values are above the specified ``threshold``.
    split_name
        Which split to compute the gene stats on.
    channel
        Name of the channel to read the expression values from. If not set (default), the default ``.X`` channel is
        used.
    channel_type
        Type of the selected channel. Must be ``"layers"`` whenever ``channel`` is specified.
    **kwargs
        Forwarded unchanged to the ``BaseTransform`` constructor.

    Raises
    ------
    TypeError
        If ``genestats_select`` is neither a string nor a list/tuple of strings.
    ValueError
        If any selected gene stats name is not registered in ``GENESTATS_FUNCS``, or if ``channel`` is specified
        while ``channel_type`` is not ``"layers"``.

    """

    _DISPLAY_ATTRS = ("genestats_select", "threshold", "pseudo", "split_name")

    def __init__(self, genestats_select: Union[str, List[str]] = "all", *, fill_na: Optional[float] = None,
                 threshold: float = 0, pseudo: bool = False, split_name: Optional[str] = "train",
                 channel: Optional[str] = None, channel_type: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)

        # Normalize the selection into a fresh list. Previously, anything other than "all" or a ``list`` (e.g. a
        # single name such as "mu", or a tuple) silently left ``self.genestats_select`` unset, which only surfaced
        # later as an AttributeError.
        if isinstance(genestats_select, str):
            if genestats_select == "all":
                selected = list(GENESTATS_FUNCS)
            else:
                selected = [genestats_select]
        elif isinstance(genestats_select, (list, tuple)):
            # Copy so that later mutation of the caller's sequence does not affect this transform.
            selected = list(genestats_select)
        else:
            raise TypeError("genestats_select must be a string or a list of strings, "
                            f"got {type(genestats_select).__name__}")

        # Check genestats options
        invalid_options = [i for i in selected if i not in GENESTATS_FUNCS]
        if invalid_options:
            raise ValueError(f"The following genestats selections are unavailable:\n{pformat(invalid_options)}\n"
                             f"Currently supported genestats options are {pformat(list(GENESTATS_FUNCS))}")
        self.genestats_select = selected

        # Set kwargs to be used by genestats functions
        self.func_kwargs = {
            "threshold": threshold,
            "pseudo": pseudo,
        }
        self.fill_na = fill_na
        self.split_name = split_name

        # Check expression layer option
        if (channel is not None) and (channel_type != "layers"):
            raise ValueError("Only the `layers` channels are available for selection other than the default `.X` "
                             "channel.\nPlease set `channel_type='layers'` to acknowledge this and resolve the "
                             "error.")
        self.channel = channel
        self.channel_type = channel_type

    def __call__(self, data):
        """Compute the selected gene stats on ``data`` and store them in ``data.data.varm[self.out]``.

        The result is a :class:`pandas.DataFrame` indexed by ``data.data.var_names`` with one column per selected
        gene stats function. Nothing is returned; ``data`` is updated in place.

        """
        # NOTE(review): ``self.out`` and ``self.logger`` are not defined in this class; presumably they are provided
        # by ``BaseTransform`` -- confirm.
        exp = data.get_feature(return_type="numpy", split_name=self.split_name, channel=self.channel,
                               channel_type=self.channel_type)

        self.logger.info(f"Start computing gene stats: {self.genestats_select}")
        stats_dict = {}
        for name in self.genestats_select:
            func = GENESTATS_FUNCS[name]
            stats_dict[name] = func(exp, **self.func_kwargs)

        stats_df = pd.DataFrame(stats_dict, index=data.data.var_names)
        if self.fill_na is not None:
            # Stats with a zero denominator come out as nan; replace them only when requested.
            stats_df = stats_df.fillna(self.fill_na)

        data.data.varm[self.out] = stats_df
# Registered gene stats functions.
#
# Each function takes the expression matrix ``exp`` and reduces it along axis 0, i.e. one value per column.
# ``GeneStats.__call__`` indexes the results by ``var_names``, so columns presumably correspond to genes and rows to
# cells -- confirm against ``data.get_feature``. Extra keyword arguments are accepted and ignored so that every
# function can be called with the same set of options. The result is post-processed by ``as_1d_array``
# (NOTE(review): presumably flattens the output to a 1-d array -- confirm in ``dance.utils.wrappers``).


@register_genestats_func("mu")
@as_1d_array
def genestats_mu(exp, threshold: float = 0, **kwargs):
    """Mean of the expression values that are strictly above ``threshold``.

    A column with no value above ``threshold`` has a zero denominator, so its result is not a finite number.

    """
    above = (exp > threshold).astype(float)
    masked_sum = (exp * above).sum(0)
    n_above = above.sum(0)
    return masked_sum / n_above


@register_genestats_func("alpha")
@as_1d_array
def genestats_alpha(exp, threshold: float = 0, pseudo: bool = False, **kwargs):
    """Ratio of entries, per column, whose expression values are strictly above ``threshold``.

    When ``pseudo`` is truthy, ``1`` is added to both the numerator and the denominator.

    """
    n_above = (exp > threshold).astype(float).sum(0)
    n_rows = exp.shape[0]
    if not pseudo:
        return n_above / n_rows
    return (n_above + 1) / (n_rows + 1)


@register_genestats_func("mean_all")
@as_1d_array
def genestats_mean_all(exp, **kwargs):
    """Mean over all expression values of each column."""
    column_mean = exp.mean(0)
    return column_mean


@register_genestats_func("cov_all")
@as_1d_array
def genestats_cov_all(exp, **kwargs):
    """Standard deviation divided by the mean, over all expression values of each column."""
    column_std = exp.std(0)
    column_mean = exp.mean(0)
    return column_std / column_mean


@register_genestats_func("fano_all")
@as_1d_array
def genestats_fano_all(exp, **kwargs):
    """Variance divided by the mean, over all expression values of each column."""
    column_var = exp.var(0)
    column_mean = exp.mean(0)
    return column_var / column_mean


@register_genestats_func("max_all")
@as_1d_array
def genestats_max_all(exp, **kwargs):
    """Maximum over all expression values of each column."""
    column_max = exp.max(0)
    return column_max


@register_genestats_func("std_all")
@as_1d_array
def genestats_std_all(exp, **kwargs):
    """Standard deviation over all expression values of each column."""
    column_std = exp.std(0)
    return column_std