# Source code for dance.transforms.mask

import logging
from typing import Literal, Optional

import numpy as np
from scipy.sparse import spmatrix
from scipy.stats import expon

from dance.registry import register_preprocessor
from dance.transforms.base import BaseTransform
from dance.typing import Literal, Optional

# @register_preprocessor("split", "entry")
# class CellwiseMaskData(BaseTransform):
#     """Randomly mask data in a cell-wise approach.

#     For every cell that has more than 5 positive counts, mask positive counts according to masking rate and probability
#     generated from distribution.

#     Parameters
#     ----------
#     distr
#         Distribution to generate masks.
#     mask_rate
#         Masking rate.
#     seed:
#         Random seed.
#     min_gene_counts
#         Minimum number of genes expressed within a cell below which we do not mask that cell.

#     """

#     _DISPLAY_ATTRS = ("distr", "mask_rate", "seed")

#     def __init__(self, distr: Optional[Literal["exp", "uniform"]] = "exp", mask_rate: Optional[float] = 0.1,
#                  seed: Optional[int] = None, min_gene_counts: int = 5, **kwargs):
#         super().__init__(**kwargs)
#         self.distr = distr
#         self.mask_rate = mask_rate
#         self.seed = seed
#         self.min_gene_counts = min_gene_counts

#     def _get_probs(self, vec):
#         if self.distr == "exp":
#             prob = expon.pdf(vec, 0, 20)
#         elif self.distr == "uniform":
#             prob = np.ones(len(vec))
#         else:
#             raise ValueError(f"Unknown distribution function option {self.distr!r}, "
#                              "available options are: 'exp', 'uniform'")
#         return prob / prob.sum()

#     def __call__(self, data):
#         rng = np.random.default_rng(self.seed)
#         feat = data.get_feature(return_type="sparse")
#         train_mask = np.ones(feat.shape, dtype=bool)

#         for c in range(feat.shape[0]):
#             # Retrieve indices of positive values
#             ind_pos = np.nonzero(feat[c])[-1]
#             cells_c_pos = feat[c, ind_pos]

#             # Get masking probability of each value
#             if cells_c_pos.size > self.min_gene_counts:
#                 prob = self._get_probs(cells_c_pos.toarray()[0])
#                 n_masked = int(np.floor(cells_c_pos.size * self.mask_rate))
#                 if n_masked >= cells_c_pos.size:
#                     self.logger.warning(f"Too many genes masked for cell {c} ({n_masked}/{cells_c_pos.size})")
#                     n_masked = 1 + int(np.floor(0.5 * cells_c_pos.size))

#                 masked_idx = rng.choice(cells_c_pos.size, n_masked, p=prob, replace=False)
#                 train_mask[c, ind_pos[masked_idx]] = False

#         data.data.layers["train_mask"] = train_mask
#         data.data.layers["valid_mask"] = ~train_mask

#         return data


@register_preprocessor("split", "entry")
class CellwiseMaskData(BaseTransform):
    """Randomly mask data in a cell-wise approach.

    For every cell (row) that has more than ``min_gene_counts`` non-zero entries, a fraction
    ``mask_rate`` of those entries is drawn without replacement and removed from the training
    mask. The removed entries are assigned to the validation mask and, optionally, to a test mask.

    Parameters
    ----------
    distr
        Distribution used to weight the non-zero entries when drawing the ones to mask.
        ``"exp"`` weights each entry by ``expon.pdf(value, 0, 20)``; ``"uniform"`` (or ``None``)
        draws uniformly. Any other value raises ``ValueError`` when the transform is called.
    mask_rate
        Proportion of non-zero entries to mask per cell. Must lie in ``[0, 1]``.
    seed
        Seed passed to ``np.random.default_rng`` on every call.
    min_gene_counts
        Cells with this many non-zero entries or fewer are left unmasked.
    add_test_mask
        If True, the masked entries of each cell are split into validation and test sets:
        about 10% (at least one entry) go to ``valid_mask`` and the rest go to ``test_mask``.
        If False, all masked entries go to ``valid_mask`` and ``test_mask`` is all False.
    **kwargs
        Additional keyword arguments passed to the base class.

    Raises
    ------
    ValueError
        If ``mask_rate`` is outside ``[0, 1]``.

    """

    _DISPLAY_ATTRS = ("distr", "mask_rate", "seed", "min_gene_counts", "add_test_mask")

    # Share of the masked entries of a cell that goes to validation when ``add_test_mask`` is set.
    # TODO change to 0.7
    _VALID_FRACTION = 0.1

    def __init__(
            self,
            distr: Optional[Literal["exp", "uniform"]] = "exp",
            mask_rate: Optional[float] = 0.1,
            seed: Optional[int] = None,
            min_gene_counts: int = 5,
            add_test_mask: bool = False,
            **kwargs):
        super().__init__(**kwargs)
        self.distr = distr
        if not 0.0 <= mask_rate <= 1.0:
            raise ValueError(f"mask_rate must be between 0 and 1, got {mask_rate}")
        self.mask_rate = mask_rate
        self.seed = seed
        self.min_gene_counts = min_gene_counts
        self.add_test_mask = add_test_mask

    def _get_probs(self, vec):
        """Return normalized sampling probabilities for the values in ``vec``.

        Raises
        ------
        ValueError
            If ``self.distr`` is neither ``"exp"`` nor ``"uniform"``.

        """
        if self.distr == "exp":
            # Exponential PDF with loc=0 and scale=20: larger values get smaller weights.
            prob = expon.pdf(vec, 0, 20)
        elif self.distr == "uniform":
            prob = np.ones(len(vec))
        else:
            raise ValueError(f"Unknown distribution function option {self.distr!r}, "
                             "available options are: 'exp', 'uniform'")

        prob_sum = prob.sum()
        if prob_sum > 1e-9:
            return prob / prob_sum

        # The weights are all (effectively) zero, e.g. empty input or an underflowing PDF.
        self.logger.warning("Probability sum is zero, falling back to uniform probability.")
        return np.ones(len(vec)) / len(vec) if len(vec) > 0 else np.array([])

    def _log_mask_stats(self, train_mask, valid_mask, test_mask):
        """Log how many entries ended up in each mask."""
        n_total = train_mask.shape[0] * train_mask.shape[1]
        n_train = train_mask.sum()
        n_valid = valid_mask.sum()
        n_test = test_mask.sum()
        n_masked_total = n_valid + n_test

        self.logger.info(f"Masking complete. Total elements: {n_total}")
        self.logger.info(f"  Train mask: {n_train} elements ({n_train/n_total:.2%})")
        if self.add_test_mask:
            self.logger.info(f"  Valid mask: {n_valid} elements ({n_valid/n_total:.4%})")
            self.logger.info(f"  Test mask:  {n_test} elements ({n_test/n_total:.4%})")
            if n_masked_total > 0:
                self.logger.info(f"  Validation split of masked: {n_valid / n_masked_total:.2%}")
                self.logger.info(f"  Test split of masked:       {n_test / n_masked_total:.2%}")
        else:
            self.logger.info(f"  Valid mask: {n_valid} elements ({n_valid/n_total:.4%}) (Test mask not created)")

    def __call__(self, data):
        """Apply the cell-wise masking and store the masks as layers.

        Parameters
        ----------
        data
            Object exposing ``get_feature(return_type="sparse")``, which must return a scipy
            sparse matrix, and a ``data.layers`` mapping that accepts the mask arrays.
            The matrix is read as cells x genes (rows are treated as cells).

        Returns
        -------
        data
            The input object with the boolean layers ``"train_mask"``, ``"valid_mask"`` and
            ``"test_mask"`` written to ``data.data.layers``.

        Raises
        ------
        ValueError
            If ``self.distr`` is not ``"exp"``, ``"uniform"`` or ``None``.
        AttributeError
            If ``data`` does not expose ``data.layers``.
        TypeError
            If the feature matrix is not a scipy sparse matrix.

        """
        # Validate the configuration and the container before doing any expensive work, so a
        # mistyped option or an unsupported container is reported instead of silently ignored.
        if self.distr not in ("exp", "uniform", None):
            raise ValueError(f"Unknown distribution function option {self.distr!r}, "
                             "available options are: 'exp', 'uniform'")
        if not hasattr(data, "data") or not hasattr(data.data, "layers"):
            raise AttributeError("Input data object does not have the expected structure 'data.data.layers'")

        rng = np.random.default_rng(self.seed)
        feat = data.get_feature(return_type="sparse")
        if not isinstance(feat, spmatrix):
            raise TypeError(f"Expected feature data to be a scipy sparse matrix, got {type(feat)}")

        # Row access through indptr/indices is only meaningful for CSR: in a CSC matrix
        # ``indices`` holds row indices, and COO cannot be row-indexed at all.
        feat = feat.tocsr()
        indptr, indices, values = feat.indptr, feat.indices, feat.data

        n_cells, n_genes = feat.shape
        train_mask = np.ones((n_cells, n_genes), dtype=bool)
        valid_mask = np.zeros((n_cells, n_genes), dtype=bool)
        test_mask = np.zeros((n_cells, n_genes), dtype=bool)

        for c in range(n_cells):
            row = slice(indptr[c], indptr[c + 1])
            # Explicitly stored zeros are not expressed genes and must not be candidates.
            stored_nonzero = values[row] != 0
            ind_pos = indices[row][stored_nonzero]
            cell_values = values[row][stored_nonzero]
            num_positive = len(ind_pos)

            # Only mask cells with enough expressed genes.
            if num_positive <= self.min_gene_counts:
                continue

            n_masked = int(np.floor(num_positive * self.mask_rate))
            if n_masked <= 0:
                continue
            if n_masked >= num_positive:
                self.logger.warning(f"Mask rate {self.mask_rate} resulted in attempting to mask all "
                                    f"{num_positive} positive counts for cell {c}. Reducing mask count.")
                # Mask roughly half of the entries instead of all of them.
                n_masked = 1 + int(np.floor(0.5 * num_positive))

            # ``p=None`` makes ``rng.choice`` sample uniformly.
            prob = self._get_probs(cell_values) if self.distr == "exp" else None
            if prob is not None and (len(prob) != num_positive or not np.isclose(prob.sum(), 1.0)):
                self.logger.warning(f"Invalid probabilities calculated for cell {c}. Falling back to uniform.")
                prob = None

            try:
                masked_relative_idx = rng.choice(num_positive, size=n_masked, p=prob, replace=False)
            except ValueError as e:
                # Raised e.g. when fewer entries have non-zero probability than are requested.
                self.logger.error(f"Error during rng.choice for cell {c}: {e}. Skipping masking for this cell.")
                continue

            # Translate positions within the non-zero entries into column indices.
            masked_absolute_indices = ind_pos[masked_relative_idx]
            train_mask[c, masked_absolute_indices] = False

            if not self.add_test_mask:
                # All masked entries go to the validation mask.
                valid_mask[c, masked_absolute_indices] = True
                continue

            n_total_masked_in_cell = len(masked_absolute_indices)
            n_valid = int(np.round(n_total_masked_in_cell * self._VALID_FRACTION))
            # Keep at least one validation entry; a single masked entry goes to validation.
            n_valid = max(1, n_valid) if n_total_masked_in_cell > 1 else n_total_masked_in_cell

            shuffled_masked_indices = rng.permutation(masked_absolute_indices)
            valid_mask[c, shuffled_masked_indices[:n_valid]] = True
            test_mask[c, shuffled_masked_indices[n_valid:]] = True

        data.data.layers["train_mask"] = train_mask
        data.data.layers["valid_mask"] = valid_mask
        # The test mask is always written, even when it is all False.
        data.data.layers["test_mask"] = test_mask

        self._log_mask_stats(train_mask, valid_mask, test_mask)
        return data
@register_preprocessor("split", "entry")
class MaskData(BaseTransform):
    """Randomly mask data.

    A fraction ``mask_rate`` of the non-zero entries of the feature matrix is drawn uniformly
    without replacement and held out: those entries are False in ``train_mask`` and True in
    ``valid_mask``. Every other entry (including all zeros) stays in ``train_mask``.

    Parameters
    ----------
    mask_rate
        Proportion of non-zero entries to mask. Must lie in ``[0, 1]``.
    seed
        Seed passed to ``np.random.default_rng`` on every call.
    **kwargs
        Additional keyword arguments passed to the base class.

    Raises
    ------
    ValueError
        If ``mask_rate`` is outside ``[0, 1]``.

    """

    _DISPLAY_ATTRS = ("mask_rate", "seed")

    def __init__(self, mask_rate: Optional[float] = 0.1, seed: Optional[int] = None, **kwargs):
        super().__init__(**kwargs)
        if not 0.0 <= mask_rate <= 1.0:
            raise ValueError(f"mask_rate must be between 0 and 1, got {mask_rate}")
        self.mask_rate = mask_rate
        self.seed = seed

    def _get_probs(self, vec):
        """Return sampling weights for ``vec``, or None if no known distribution is configured.

        This class never sets ``distr`` itself, so the attribute is looked up defensively; the
        helper is not used by ``__call__``.
        """
        distr = getattr(self, "distr", None)
        if distr == "exp":
            return expon.pdf(vec, 0, 20)
        if distr == "uniform":
            return np.tile([1. / len(vec)], len(vec))
        return None

    def __call__(self, data):
        """Mask a random subset of the non-zero entries and store the masks as layers.

        Parameters
        ----------
        data
            Object exposing ``get_feature(return_type="default")`` and a ``data.layers`` mapping.

        Returns
        -------
        data
            The input object with the boolean layers ``"train_mask"`` and ``"valid_mask"``
            written to ``data.data.layers``; ``valid_mask`` is the complement of ``train_mask``.

        """
        rng = np.random.default_rng(self.seed)
        feat = data.get_feature(return_type="default")
        train_mask = np.ones(feat.shape, dtype=bool)
        row, col = np.nonzero(feat)
        num_nonzero = len(row)

        # Hold out exactly the requested share of the non-zero entries. Drawing the
        # complementary count here would mask ``1 - mask_rate`` of them instead.
        num_masked = int(np.floor(num_nonzero * self.mask_rate))
        if num_masked > 0:
            mask_idx = rng.choice(num_nonzero, size=num_masked, replace=False)
            train_mask[row[mask_idx], col[mask_idx]] = False

        data.data.layers["train_mask"] = train_mask
        data.data.layers["valid_mask"] = ~train_mask
        return data