Source code for dance.transforms.normalize

import numpy as np
import scipy.sparse as sp

from dance.transforms.base import BaseTransform
from dance.typing import Dict, List, Literal, NormMode, Optional, Union
from dance.utils.matrix import normalize


[docs]class ScaleFeature(BaseTransform): """Scale the feature matrix in the AnnData object. This is an extension of :meth:`scanpy.pp.scale`, allowing split- or batch-wide scaling. Parameters ---------- axis Axis along which the scaling is performed. split_names Indicate which splits to perform the scaling independently. If set to 'ALL', then go through all splits available in the data. batch_key Indicate which column in ``.obs`` to use as the batch index to guide the batch-wide scaling. mode Scaling mode, see :meth:`dance.utils.matrix.normalize` for more information. eps Correction fact, see :meth:`dance.utils.matrix.normalize` for more information. Note ---- The order of checking split- or batch-wide scaling mode is: batch_key > split_names > None (i.e., all). """ _DISPLAY_ATTRS = ("axis", "mode", "eps", "split_names", "batch_key") def __init__( self, *, axis: int = 0, split_names: Optional[Union[Literal["ALL"], List[str]]] = None, batch_key: Optional[str] = None, mode: NormMode = "normalize", eps: float = -1, **kwargs, ): super().__init__(**kwargs) self.axis = axis self.split_names = split_names self.batch_key = batch_key self.mode = mode self.eps = eps def _get_idx_dict(self, data) -> List[Dict[str, List[int]]]: batch_key = self.batch_key split_names = self.split_names if batch_key is not None: if split_names is not None: raise ValueError("Exactly one of split_names and batch_key can be specified, got: " f"split_names={split_names!r}, batch_key={batch_key!r}") elif batch_key not in (avail_opts := data.data.obs.columns.tolist()): raise KeyError(f"{batch_key=!r} not found in `.obs`. Available columns are: {avail_opts}") batch_col = data.data.obs[batch_key] idx_dict = {f"batch:{i}": np.where(batch_col[0] == i)[0].tolist() for i in batch_col[0].unique()} return idx_dict if split_names is None: idx_dict = {"full": list(range(data.shape[0]))} elif isinstance(split_names, str) and split_names == "ALL": idx_dict = {f"split:{i}": j for i, j in data._split_idx_dict.items()} elif isinstance(split_names, list): idx_dict = {f"split:{i}": data.get_split_idx(i) for i in split_names} else: raise TypeError(f"Unsupported type {type(split_names)} for split_names: {split_names!r}") return idx_dict def __call__(self, data): if isinstance(data.data.X, sp.spmatrix): self.logger.warning("Native support for sparse matrix is not implemented yet, " "converting to dense array explicitly.") data.data.X = data.data.X.A idx_dict = self._get_idx_dict(data) for name, idx in idx_dict.items(): self.logger.info(f"Scaling {name} (n={len(idx):,})") data.data.X[idx] = normalize(data.data.X[idx], mode=self.mode, axis=self.axis, eps=self.eps)