Source code for dance.transforms.mask

import numpy as np
from scipy.stats import expon

from dance.transforms.base import BaseTransform
from dance.typing import Literal, Optional


class CellwiseMaskData(BaseTransform):
    """Randomly mask data in a cell-wise approach.

    For every cell that has more than ``min_gene_counts`` nonzero counts, mask a fraction of those counts according
    to the masking rate, sampling which entries to mask with probabilities generated from the chosen distribution.

    Parameters
    ----------
    distr
        Distribution used to generate the sampling probabilities, either ``"exp"`` or ``"uniform"``.
    mask_rate
        Masking rate, i.e., the fraction of nonzero counts masked within each eligible cell.
    seed
        Random seed.
    min_gene_counts
        Minimum number of genes expressed within a cell, at or below which we do not mask that cell.

    """

    _DISPLAY_ATTRS = ("distr", "mask_rate", "seed")

    def __init__(self, distr: Optional[Literal["exp", "uniform"]] = "exp", mask_rate: Optional[float] = 0.1,
                 seed: Optional[int] = None, min_gene_counts: int = 5, **kwargs):
        super().__init__(**kwargs)

        self.distr = distr
        self.mask_rate = mask_rate
        self.seed = seed
        self.min_gene_counts = min_gene_counts

    def _get_probs(self, vec):
        """Return normalized sampling probabilities for the values in ``vec``.

        Raises
        ------
        ValueError
            If ``self.distr`` is neither ``"exp"`` nor ``"uniform"``.

        """
        if self.distr == "exp":
            # Exponential density (loc=0, scale=20) evaluated at each value, so smaller values get larger weights.
            prob = expon.pdf(vec, 0, 20)
        elif self.distr == "uniform":
            prob = np.ones(len(vec))
        else:
            raise ValueError(f"Unknown distribution function option {self.distr!r}, "
                             "available options are: 'exp', 'uniform'")
        return prob / prob.sum()

    def __call__(self, data):
        """Compute the masks and store them as the ``"train_mask"`` and ``"valid_mask"`` layers of ``data.data``.

        ``train_mask`` is ``True`` everywhere except at the masked entries; ``valid_mask`` is its complement.
        The input ``data`` object is returned after the layers have been written.

        """
        rng = np.random.default_rng(self.seed)
        # NOTE(review): presumably a (cells x genes) matrix, either scipy sparse or a dense ndarray - confirm.
        feat = data.get_feature(return_type="default")
        train_mask = np.ones(feat.shape, dtype=bool)

        for c in range(feat.shape[0]):
            # Retrieve column indices of the nonzero values in this row. ``[-1]`` picks the column indices both
            # for a 2-D sparse row (rows, cols) and for a 1-D dense row (cols,).
            ind_pos = np.nonzero(feat[c])[-1]
            cells_c_pos = feat[c, ind_pos]
            # Sparse selections need to be densified; dense arrays have no ``toarray`` and are used as they are.
            if hasattr(cells_c_pos, "toarray"):
                cells_c_pos = cells_c_pos.toarray()
            cells_c_pos = np.asarray(cells_c_pos).ravel()

            # Cells with too few expressed genes are left completely unmasked.
            if cells_c_pos.size <= self.min_gene_counts:
                continue

            # Get masking probability of each value
            prob = self._get_probs(cells_c_pos)
            n_masked = int(np.floor(cells_c_pos.size * self.mask_rate))
            if n_masked >= cells_c_pos.size:
                # Never mask a whole cell: fall back to masking just over half of its nonzero values.
                self.logger.warning(f"Too many genes masked for cell {c} ({n_masked}/{cells_c_pos.size})")
                n_masked = 1 + int(np.floor(0.5 * cells_c_pos.size))

            masked_idx = rng.choice(cells_c_pos.size, n_masked, p=prob, replace=False)
            train_mask[c, ind_pos[masked_idx]] = False

        data.data.layers["train_mask"] = train_mask
        data.data.layers["valid_mask"] = ~train_mask

        return data
class MaskData(BaseTransform):
    """Randomly mask data.

    Randomly mask nonzero counts according to the masking rate.

    Parameters
    ----------
    mask_rate
        Masking rate, i.e., the fraction of nonzero counts that are masked.
    seed
        Random seed.

    """

    _DISPLAY_ATTRS = ("mask_rate", "seed")

    def __init__(self, mask_rate: Optional[float] = 0.1, seed: Optional[int] = None, **kwargs):
        super().__init__(**kwargs)

        self.mask_rate = mask_rate
        self.seed = seed

    def _get_probs(self, vec):
        """Return sampling weights for ``vec``; this helper is not used by ``__call__``.

        This class does not set a ``distr`` attribute itself, so ``"uniform"`` is assumed when it is absent.
        ``None`` is returned for an unrecognized ``distr`` value.

        """
        distr = getattr(self, "distr", "uniform")
        if distr == "exp":
            return expon.pdf(vec, 0, 20)
        if distr == "uniform":
            return np.tile([1. / len(vec)], len(vec))
        return None

    def __call__(self, data):
        """Compute the masks and store them as the ``"train_mask"`` and ``"valid_mask"`` layers of ``data.data``.

        ``train_mask`` is ``True`` everywhere except at the masked entries; ``valid_mask`` is its complement.
        The input ``data`` object is returned after the layers have been written.

        """
        rng = np.random.default_rng(self.seed)
        feat = data.get_feature(return_type="default")
        train_mask = np.ones(feat.shape, dtype=bool)
        row, col = np.nonzero(feat)
        num_nonzero = len(row)

        # Randomly mask nonzero counts according to masking rate: exactly floor(num_nonzero * mask_rate) entries
        # are held out, and all remaining entries stay available for training.
        num_masked = int(np.floor(num_nonzero * self.mask_rate))
        mask_idx = rng.choice(num_nonzero, size=num_masked, replace=False)
        train_mask[row[mask_idx], col[mask_idx]] = False

        data.data.layers["train_mask"] = train_mask
        data.data.layers["valid_mask"] = ~train_mask

        return data