"""Cell and gene filtering transforms (source code for ``dance.transforms.filter``)."""

import warnings
from abc import ABC
from typing import get_args

import anndata as ad
import mudata as md
import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import scipy.sparse as sp
from scipy.stats import median_abs_deviation, rankdata
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import MinMaxScaler, PolynomialFeatures

from dance import logger as default_logger
from dance.data.base import Data
from dance.exceptions import DevError
from dance.registry import register_preprocessor
from dance.transforms.base import BaseTransform
from dance.transforms.interface import AnnDataTransform
from dance.typing import Dict, GeneSummaryMode, List, Literal, Logger, Optional, Tuple, Union
from dance.utils import default
from dance.utils.status import deprecated
from dance.utils.wrappers import add_mod_and_transform


def get_count(count_or_ratio: Optional[Union[float, int]], total: int) -> Optional[int]:
    """Get the count from a count or ratio.

    Parameters
    ----------
    count_or_ratio
        Either an absolute count (int) or a ratio (float in ``[0, 1]``) of ``total``. If None, then return None.
        NumPy integer and floating scalars are accepted as well.
    total
        Total number that a ratio is relative to, and that a count cannot exceed.

    Returns
    -------
    Optional[int]
        ``None`` if ``count_or_ratio`` is None, ``int(count_or_ratio * total)`` for a ratio, otherwise the count.

    Raises
    ------
    ValueError
        If a ratio is greater than 1, a count is greater than ``total``, or the value is negative.
    TypeError
        If ``count_or_ratio`` is neither a float nor an int.

    """
    if count_or_ratio is None:
        return None
    # ``np.float32`` is not a subclass of ``float`` and ``np.int64`` is not a subclass of ``int``, so NumPy
    # scalar types are listed explicitly.
    if isinstance(count_or_ratio, (float, np.floating)):
        if count_or_ratio > 1.:
            raise ValueError(f"{count_or_ratio=} is greater than 1. Ratio cannot be greater than 1.")
        if count_or_ratio < 0:
            raise ValueError(f"{count_or_ratio=} is negative. Ratio cannot be negative.")
        return int(count_or_ratio * total)
    if isinstance(count_or_ratio, (int, np.integer)):
        if count_or_ratio > total:
            raise ValueError(f"{count_or_ratio=} is greater than {total=}")
        if count_or_ratio < 0:
            raise ValueError(f"{count_or_ratio=} is negative. Count cannot be negative.")
        return int(count_or_ratio)
    raise TypeError(f"count_or_ratio must be either float or int, got {type(count_or_ratio)}")


@register_preprocessor("filter")
@add_mod_and_transform
class FilterScanpy(BaseTransform):
    """Scanpy filtering transformation with additional options.

    Base class wrapping :func:`scanpy.pp.filter_cells` / :func:`scanpy.pp.filter_genes`. Use
    :class:`FilterCellsScanpy` or :class:`FilterGenesScanpy` instead; ``_FILTER_TARGET`` selects the target.

    Parameters
    ----------
    min_counts
        Minimum total counts required to keep an entry. A float strictly between 0 and 1 is interpreted as a
        percentile (ratio) of the per-entry total counts.
    min_genes_or_cells
        Minimum number (int) or ratio (float) of genes (cell filter) or cells (gene filter) required.
    max_counts
        Maximum total counts allowed to keep an entry; same interpretation as ``min_counts``.
    max_genes_or_cells
        Maximum number (int) or ratio (float) of genes (cell filter) or cells (gene filter) allowed.
    split_name
        Which split to be used for filtering.
    channel
        Channel to be used for filtering.
    channel_type
        Channel type to be used for filtering.
    key_n_counts
        If set, column under which the total counts per entry are saved (``.obs`` for cells, ``.var`` for genes).
    key_n_genes_or_cells
        If set, column under which the number of nonzero values per entry is saved.
    inplace
        If True, subset the data in place. Otherwise store the filtered matrix in ``.varm[self.out]`` (cell
        filter) or ``.obsm[self.out]`` (gene filter).

    """

    _FILTER_TARGET: Optional[Literal["cells", "genes"]] = None

    def __init__(
        self,
        min_counts: Optional[Union[float, int]] = None,
        min_genes_or_cells: Optional[Union[float, int]] = None,
        max_counts: Optional[Union[float, int]] = None,
        max_genes_or_cells: Optional[Union[float, int]] = None,
        split_name: Optional[str] = None,
        channel: Optional[str] = None,
        channel_type: Optional[str] = "X",
        key_n_counts: Optional[str] = None,
        key_n_genes_or_cells: Optional[str] = None,
        inplace=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.min_counts = min_counts
        self.min_genes_or_cells = min_genes_or_cells
        self.max_counts = max_counts
        self.max_genes_or_cells = max_genes_or_cells
        self.key_n_counts = key_n_counts
        self.split_name = split_name
        self.channel = channel
        self.channel_type = channel_type
        self.key_n_genes_or_cells = key_n_genes_or_cells
        self.inplace = inplace

        if self._FILTER_TARGET is None:
            raise NotImplementedError("Use FilterCellsScanpy or FilterGenesScanpy instead")
        elif self._FILTER_TARGET == "cells":
            self._subsetting_func_name = "_inplace_subset_obs"
            self._filter_func = sc.pp.filter_cells
            self.min_genes = min_genes_or_cells
            self.max_genes = max_genes_or_cells
        elif self._FILTER_TARGET == "genes":
            self._subsetting_func_name = "_inplace_subset_var"
            self._filter_func = sc.pp.filter_genes
            self.min_cells = min_genes_or_cells
            self.max_cells = max_genes_or_cells
        else:
            raise ValueError(f"Unknown filter target {self._FILTER_TARGET!r}")

    def __call__(self, data):
        x = data.get_feature(return_type="numpy", split_name=self.split_name, channel=self.channel,
                             channel_type=self.channel_type)
        total_cells, total_features = x.shape
        min_counts, max_counts = self.prepCounts(x)

        # A gene filter thresholds on the number of cells expressing each gene (ratios relative to the number of
        # cells); a cell filter thresholds on the number of genes expressed in each cell.
        filter_genes = self._FILTER_TARGET == "genes"
        basis = total_cells if filter_genes else total_features
        other_name = "cells" if filter_genes else "genes"
        opts = {
            "min_counts": min_counts,
            "max_counts": max_counts,
            f"min_{other_name}": get_count(self.min_genes_or_cells, basis),
            f"max_{other_name}": get_count(self.max_genes_or_cells, basis),
        }
        subset_ind, _ = self._filter_func(x, inplace=False, **opts)

        if self.key_n_counts is not None:
            self.logger.warning(f"{self.key_n_counts} will be added to the data")
        if filter_genes:
            if self.key_n_counts is not None:
                data.data.var[self.key_n_counts] = np.sum(x, axis=0)
            if self.key_n_genes_or_cells is not None:
                data.data.var[self.key_n_genes_or_cells] = np.sum(x > 0, axis=0)
        else:
            if self.key_n_counts is not None:
                data.data.obs[self.key_n_counts] = np.sum(x, axis=1)
            if self.key_n_genes_or_cells is not None:
                data.data.obs[self.key_n_genes_or_cells] = np.sum(x > 0, axis=1)

        # Count the removed entries: ``(~mask).sum()``; ``~mask.sum()`` would be the bitwise NOT of the kept count.
        num_removed = int((~subset_ind).sum())
        if num_removed > 0:
            self.logger.info(f"Subsetting {self._FILTER_TARGET} ({num_removed:,} removed) due to {self}")

        if self.inplace:
            if num_removed > 0:
                if filter_genes:
                    getattr(data.data, self._subsetting_func_name)(subset_ind)
                else:
                    data.filter_by_mask(subset_ind)
        elif filter_genes:
            # ``subset_ind`` masks genes (columns): keep all cells, drop filtered genes.
            data.data.obsm[self.out] = x[:, subset_ind]
        else:
            # ``subset_ind`` masks cells (rows); transpose so the result is aligned with ``.var``.
            data.data.varm[self.out] = x[subset_ind].T

    def prepCounts(self, x):
        """Resolve ``min_counts`` / ``max_counts`` into thresholds passed to scanpy.

        A float strictly between 0 and 1 is treated as a ratio and replaced by that percentile of the total counts
        per gene (gene filter) or per cell (cell filter). Any other value is returned unchanged. The two bounds are
        resolved independently, so a ratio bound can be combined with an absolute one.

        Parameters
        ----------
        x
            Feature matrix with cells as rows and features as columns.

        Returns
        -------
        The resolved ``(min_counts, max_counts)`` pair.

        """

        def is_ratio(val) -> bool:
            return isinstance(val, float) and 0 < val < 1

        if not (is_ratio(self.min_counts) or is_ratio(self.max_counts)):
            return self.min_counts, self.max_counts

        n_counts = np.sum(x, axis=0 if self._FILTER_TARGET == "genes" else 1)
        min_counts = np.percentile(n_counts, self.min_counts * 100) if is_ratio(self.min_counts) else self.min_counts
        max_counts = np.percentile(n_counts, self.max_counts * 100) if is_ratio(self.max_counts) else self.max_counts
        return min_counts, max_counts
@register_preprocessor("filter", "cell")
class FilterCellsScanpy(FilterScanpy):
    """Filter cells with scanpy; the gene-count bounds may be given as ratios.

    Parameters
    ----------
    min_counts
        Lower bound on the total counts of a cell for it to be kept.
    min_genes
        Lower bound on the number (int) or ratio (float) of genes expressed in a cell for it to be kept.
    max_counts
        Upper bound on the total counts of a cell for it to be kept.
    max_genes
        Upper bound on the number (int) or ratio (float) of genes expressed in a cell for it to be kept.
    split_name
        Split whose data is used for filtering.
    channel
        Channel whose data is used for filtering.
    channel_type
        Type of the channel used for filtering.
    key_n_counts
        Where to store n_counts (the total counts of each cell); nothing is stored when None.
    key_n_genes
        Where to store n_genes (the number of genes expressed in each cell); nothing is stored when None.
    inplace
        When True, the original data is replaced with the filtered data; when False, the filtered data is stored
        in varm.

    """

    _DISPLAY_ATTRS = ("min_counts", "min_genes", "max_counts", "max_genes", "split_name")
    _FILTER_TARGET = "cells"

    def __init__(
        self,
        min_counts: Optional[Union[float, int]] = None,
        min_genes: Optional[Union[float, int]] = None,
        max_counts: Optional[Union[float, int]] = None,
        max_genes: Optional[Union[float, int]] = None,
        split_name: Optional[str] = None,
        channel: Optional[str] = None,
        channel_type: Optional[str] = "X",
        key_n_counts: Optional[str] = None,
        key_n_genes: Optional[str] = None,
        inplace=True,
        **kwargs,
    ):
        # Map the cell-specific argument names onto the generic ones understood by FilterScanpy.
        super().__init__(
            split_name=split_name,
            channel=channel,
            channel_type=channel_type,
            min_counts=min_counts,
            max_counts=max_counts,
            min_genes_or_cells=min_genes,
            max_genes_or_cells=max_genes,
            key_n_counts=key_n_counts,
            key_n_genes_or_cells=key_n_genes,
            inplace=inplace,
            **kwargs,
        )
@register_preprocessor("filter", "gene")
class FilterGenesScanpy(FilterScanpy):
    """Filter genes with scanpy; the cell-count bounds may be given as ratios.

    Parameters
    ----------
    min_counts
        Lower bound on the total counts of a gene for it to be kept.
    min_cells
        Lower bound on the number (int) or ratio (float) of cells expressing a gene for it to be kept.
    max_counts
        Upper bound on the total counts of a gene for it to be kept.
    max_cells
        Upper bound on the number (int) or ratio (float) of cells expressing a gene for it to be kept.
    split_name
        Split whose data is used for filtering.
    channel
        Channel whose data is used for filtering.
    channel_type
        Type of the channel used for filtering.
    key_n_counts
        Where to store n_counts (the total counts of each gene); nothing is stored when None.
    key_n_cells
        Where to store n_cells (the number of cells expressing each gene); nothing is stored when None.
    inplace
        When True, the original data is replaced with the filtered data; when False, the filtered data is stored
        in obsm.

    """

    _DISPLAY_ATTRS = ("min_counts", "min_cells", "max_counts", "max_cells", "split_name")
    _FILTER_TARGET = "genes"

    def __init__(
        self,
        min_counts: Optional[Union[float, int]] = None,
        min_cells: Optional[Union[float, int]] = None,
        max_counts: Optional[Union[float, int]] = None,
        max_cells: Optional[Union[float, int]] = None,
        split_name: Optional[str] = None,
        channel: Optional[str] = None,
        channel_type: Optional[str] = "X",
        key_n_counts: Optional[str] = None,
        key_n_cells: Optional[str] = None,
        inplace=True,
        **kwargs,
    ):
        # Map the gene-specific argument names onto the generic ones understood by FilterScanpy.
        super().__init__(
            split_name=split_name,
            channel=channel,
            channel_type=channel_type,
            min_counts=min_counts,
            max_counts=max_counts,
            min_genes_or_cells=min_cells,
            max_genes_or_cells=max_cells,
            key_n_counts=key_n_counts,
            key_n_genes_or_cells=key_n_cells,
            inplace=inplace,
            **kwargs,
        )
@register_preprocessor("filter", "cell")
@add_mod_and_transform
class FilterCellsCommonMod(BaseTransform):
    """Restrict two modalities (and optionally a solution modality) to the cells they have in common.

    Parameters
    ----------
    mod1 : str
        Name of the first modality in the single-cell dataset.
    mod2 : str
        Name of the second modality in the single-cell dataset.
    sol : Optional[str], default=None
        Name of the optional solution dataset containing cell labels or annotations. If set, it is subset to the
        same common cells.
    **kwargs : dict
        Additional keyword arguments passed to the base transformation class.

    Notes
    -----
    The common cells are kept in the order in which they first appear in ``mod1``.

    """

    def __init__(self, mod1: str, mod2: str, sol: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        self.mod1 = mod1
        self.mod2 = mod2
        self.sol = sol

    def __call__(self, data: Data):
        md_data = data.data
        data_mod1 = md_data.mod[self.mod1]
        data_mod2 = md_data.mod[self.mod2]

        # Iterating a set of strings is not reproducible across interpreter runs (hash randomization), so build
        # the intersection in mod1's order instead; ``dict.fromkeys`` de-duplicates like the set would.
        mod2_cells = set(data_mod2.obs.index)
        common_cells = list(dict.fromkeys(cell for cell in data_mod1.obs.index if cell in mod2_cells))

        md_data.mod[self.mod1] = data_mod1[common_cells, :]
        md_data.mod[self.mod2] = data_mod2[common_cells, :]
        if self.sol is not None:
            md_data.mod[self.sol] = md_data.mod[self.sol][common_cells, :]
@register_preprocessor("filter", "gene")
class FilterGenesCommon(BaseTransform):
    """Filter genes by taking the common genes across batches or splits.

    A gene counts as present in a batch/split when the sum of its absolute values over that batch/split's cells is
    positive. Only genes present in every batch/split are kept (in sorted order).

    Parameters
    ----------
    batch_key
        Which column in the ``.obs`` table to be used to distinguishing batches.
    split_keys
        A list of split names, e.g., 'train', to be used to find common genes.

    Note
    ----
    One and only one of :attr:`batch_key` or :attr:`split_keys` can be specified.

    """

    _DISPLAY_ATTRS = ("batch_key", "split_keys")

    def __init__(self, batch_key: Optional[str] = None, split_keys: Optional[List[str]] = None, **kwargs):
        super().__init__(**kwargs)

        if (batch_key is not None) and (split_keys is not None):
            raise ValueError("Either batch_key or split_keys can be specified, but not both. "
                             f"Got {batch_key=!r}, {split_keys=!r}")
        elif (batch_key is None) and (split_keys is None):
            raise ValueError("Either one of batch_key or split_keys must be specified.")
        self.batch_key = batch_key
        self.split_keys = split_keys

    def _select_by_splits(self, data) -> Dict[Union[str, int], ad.AnnData]:
        """Slice the data by each split in ``split_keys``."""
        sliced_data_dict = {}
        for split_key in self.split_keys:
            idx = data.get_split_idx(split_key, error_on_miss=True)
            sliced_data_dict[split_key] = data.data[idx]
        return sliced_data_dict

    def _select_by_batch(self, data) -> Dict[Union[str, int], ad.AnnData]:
        """Slice the data by each value of the ``batch_key`` column."""
        # ``observed=True`` skips unused categories of a categorical column; they would otherwise form empty
        # batches with no expressed genes and make the intersection empty.
        return {
            batch_id: data.data[group.index]
            for batch_id, group in data.data.obs.groupby(self.batch_key, observed=True)
        }

    def __call__(self, data):
        if self.split_keys is not None and self.batch_key is None:
            sliced_data_dict = self._select_by_splits(data)
        elif self.batch_key is not None and self.split_keys is None:
            sliced_data_dict = self._select_by_batch(data)
        else:
            raise DevError("Exactly one of batch_key or split_keys must be set. "
                           "This should have been caught at init.")

        if not sliced_data_dict:
            raise ValueError("No batches or splits found to compute the common genes from.")

        all_genes = data.data.var_names.tolist()
        sub_genes_list = []
        for name, sliced_data in sliced_data_dict.items():
            abs_sum = np.array(np.abs(sliced_data.X).sum(0)).ravel()
            sub_genes = [all_genes[i] for i in np.where(abs_sum > 0)[0]]
            sub_genes_list.append(sub_genes)
            self.logger.info(f"{len(sub_genes):,} genes found in {name!r}")

        common_genes = sorted(set.intersection(*map(set, sub_genes_list)))
        self.logger.info(f"Found {len(common_genes):,} common genes out of {len(all_genes):,} total genes.")
        data.data._inplace_subset_var(common_genes)
@register_preprocessor("filter", "gene")
class FilterGenesMatch(BaseTransform):
    """Filter genes based on prefixes and suffixes.

    Genes whose names start with any of ``prefixes`` or end with any of ``suffixes`` are removed in place.

    Parameters
    ----------
    prefixes
        List of prefixes to remove.
    suffixes
        List of suffixes to remove.
    case_sensitive
        If True, match gene names exactly as given. If False (default), ignore case by comparing upper-cased gene
        names against upper-cased prefixes/suffixes.

    """

    _DISPLAY_ATTRS = ("prefixes", "suffixes")

    def __init__(
        self,
        prefixes: Optional[List[str]] = None,
        suffixes: Optional[List[str]] = None,
        case_sensitive: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.prefixes = prefixes or []
        self.suffixes = suffixes or []
        self.case_sensitive = case_sensitive
        # Case-insensitive matching upper-cases both the patterns (here) and the gene names (in ``__call__``).
        if not case_sensitive:
            self.prefixes = [i.upper() for i in self.prefixes]
            self.suffixes = [i.upper() for i in self.suffixes]

    def __call__(self, data):
        var_names = data.data.var_names
        ids = var_names.str if self.case_sensitive else var_names.str.upper().str

        indicator = np.zeros(data.shape[1], dtype=bool)
        for name, items in (("prefix", self.prefixes), ("suffix", self.suffixes)):
            for item in items:
                new_indicator = ids.startswith(item) if name == "prefix" else ids.endswith(item)
                self.logger.info(f"{new_indicator.sum()} number of genes will be removed due to {name} {item!r}")
                indicator = np.logical_or(indicator, new_indicator)

        self.logger.info(f"Removing {indicator.sum()} genes in total")
        self.logger.debug(f"Removing genes: {data.data.var_names[indicator]}")
        data.data._inplace_subset_var(data.data.var_names[~indicator])

        return data
class FilterGenes(BaseTransform, ABC):
    """Filter genes based on the summarized gene expressions.

    Subclasses implement :meth:`_get_preserve_mask`, which receives one summary value per gene and returns a
    boolean mask of the genes to keep.

    Parameters
    ----------
    mode
        Summarization mode. The modes handled are ``sum``, ``var`` (variance), ``cv`` (std / mean), and ``rv``
        (var / mean).
    channel
        Name of a ``layers`` channel to use; the default ``.X`` is used if not set.
    channel_type
        Must be ``"layers"`` when ``channel`` is set.
    whitelist_indicators
        One or more ``.var`` columns; genes with a truthy value in any of them are always kept.
    add_n_counts
        Whether to add ``n_counts`` (total counts per gene) to ``.var``.
    add_n_cells
        Whether to add ``n_cells`` (number of cells with positive expression per gene) to ``.var``.
    inplace
        If True, subset the genes in place; otherwise store the selected genes' expression in ``.obsm[self.out]``.

    """

    def __init__(
        self,
        *,
        mode: GeneSummaryMode = "sum",
        channel: Optional[str] = None,
        channel_type: Optional[str] = None,
        whitelist_indicators: Optional[Union[str, List[str]]] = None,
        add_n_counts=True,
        add_n_cells=True,
        inplace=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if (channel is not None) and (channel_type != "layers"):
            raise ValueError(f"Only X layers is available for filtering genes, specified {channel_type=!r}")

        if mode not in (all_modes := sorted(get_args(GeneSummaryMode))):
            raise ValueError(f"Unknown summarization mode {mode!r}, available options are {all_modes}")

        self.mode = mode
        self.channel = channel
        self.channel_type = channel_type
        self.whitelist_indicators = whitelist_indicators
        self.add_n_counts = add_n_counts
        self.add_n_cells = add_n_cells
        self.inplace = inplace

    def _get_preserve_mask(self, gene_summary: np.ndarray) -> np.ndarray:
        """Select gene to be preserved and return as a boolean mask (one entry per gene).

        Must be implemented by subclasses.

        """
        raise NotImplementedError(f"{type(self).__name__} must implement _get_preserve_mask")

    def _summarize_genes(self, x) -> np.ndarray:
        """Compute one summary statistic per gene (column of ``x``) according to ``self.mode``."""
        if self.mode == "sum":
            return np.array(x.sum(0)).ravel()
        if self.mode == "var":
            x_squared = x.power(2) if isinstance(x, sp.spmatrix) else x**2
            return np.array(x_squared.mean(0) - np.square(x.mean(0))).ravel()
        if self.mode == "cv":
            # Genes with zero mean yield inf/nan, which are mapped to 0.
            return np.nan_to_num(np.array(x.std(0) / x.mean(0)), posinf=0, neginf=0).ravel()
        if self.mode == "rv":
            return np.nan_to_num(np.array(x.var(0) / x.mean(0)), posinf=0, neginf=0).ravel()
        raise DevError(f"{self.mode!r} not expected, please inform dev to fix this error.")

    def _get_whitelist_genes(self, data) -> set:
        """Return the genes flagged by any ``whitelist_indicators`` column (empty set if unset)."""
        if self.whitelist_indicators is None:
            return set()
        columns = self.whitelist_indicators
        columns = [columns] if isinstance(columns, str) else list(columns)
        indicators = data.data.var[columns]
        # ``any`` works for numeric (0/1) indicator columns as well as boolean ones.
        return set(indicators.index[indicators.any(axis=1).to_numpy()])

    def __call__(self, data):
        x = data.get_feature(return_type="numpy", channel=self.channel, channel_type=self.channel_type)
        if self.add_n_counts:
            self.logger.warning("n_counts will be added to the var of data")
            data.data.var["n_counts"] = np.sum(x, axis=0)
        if self.add_n_cells:
            self.logger.warning("n_cells will be added to the var of data")
            data.data.var["n_cells"] = np.sum(x > 0, axis=0)

        # Compute gene summary stats for filtering
        gene_summary = self._summarize_genes(x)
        self.logger.info(f"Filtering genes based on {self.mode} expression percentiles in layer {self.channel!r}")
        mask = self._get_preserve_mask(gene_summary)
        selected_genes = sorted(data.data.var_names[mask])

        # Whitelisted genes bypass the filter, but they were still part of the summary stats computed above.
        whitelist_gene_set = self._get_whitelist_genes(data)
        if whitelist_gene_set:
            orig_num_selected = len(selected_genes)
            selected_genes = sorted(set(selected_genes) | whitelist_gene_set)
            num_added = len(selected_genes) - orig_num_selected
            self.logger.info(f"{num_added:,} genes originally unselected are being added due to whitelist")
        data.data.uns['gene_summary'] = gene_summary

        # Update data
        self.logger.info(f"{data.shape[1] - len(selected_genes):,} genes removed")
        if self.inplace:
            data.data._inplace_subset_var(selected_genes)
        else:
            data.data.obsm[self.out] = data.data[:, selected_genes].X
@register_preprocessor("filter", "gene")
@add_mod_and_transform
class FilterGenesPercentile(FilterGenes):
    """Filter genes based on percentiles of the summarized gene expressions.

    Parameters
    ----------
    min_val
        Minimum percentile of the summarized expression value below which the genes will be discarded. If None,
        no lower bound is applied.
    max_val
        Maximum percentile of the summarized expression value above which the genes will be discarded. If None,
        no upper bound is applied.
    mode
        Summarization mode. Available options are ``[sum|var|cv|rv]``. ``sum`` calculates the sum of expression
        values, ``var`` calculates the variance of the expression values, ``cv`` uses the coefficient of variation
        (std / mean), and ``rv`` uses the relative variance (var / mean).
    channel
        Which channel, more specifically, ``layers``, to use. Use the default ``.X`` if not set. If ``channel`` is
        specified, then need to specify ``channel_type`` to be ``layers`` as well.
    channel_type
        Type of channels specified. Only allow ``None`` (the default setting) or ``layers`` (when ``channel`` is
        specified).
    whitelist_indicators
        A list of (or a single) :obj:`.var` columns that indicates the genes to be excluded from the filtering
        process. Note that these genes will still be used in the summary stats computation, and thus will still
        contribute to the threshold percentile. If not set, then no genes will be excluded from the filtering
        process.
    add_n_counts
        Whether to add ``n_counts``, the total counts for each gene.
    add_n_cells
        Whether to add ``n_cells``, the number of cells expressed for each gene.
    inplace
        If inplace is True, the original data is replaced with the filtered data.
        If inplace is False, the filtered data is stored in obsm.

    """

    _DISPLAY_ATTRS = ("min_val", "max_val", "mode")

    def __init__(
        self,
        min_val: Optional[float] = 1,
        max_val: Optional[float] = 99,
        *,
        mode: GeneSummaryMode = "sum",
        channel: Optional[str] = None,
        channel_type: Optional[str] = None,
        whitelist_indicators: Optional[Union[str, List[str]]] = None,
        add_n_counts=True,
        add_n_cells=True,
        inplace=True,
        **kwargs,
    ):
        super().__init__(
            mode=mode,
            channel=channel,
            channel_type=channel_type,
            whitelist_indicators=whitelist_indicators,
            add_n_counts=add_n_counts,
            add_n_cells=add_n_cells,
            inplace=inplace,
            **kwargs,
        )
        self.min_val = min_val
        self.max_val = max_val

    def _get_preserve_mask(self, gene_summary):
        """Keep genes whose summary lies within the ``[min_val, max_val]`` percentile range (inclusive)."""
        mask = np.ones(gene_summary.shape, dtype=bool)
        if self.min_val is not None:
            mask &= gene_summary >= np.percentile(gene_summary, self.min_val)
        if self.max_val is not None:
            mask &= gene_summary <= np.percentile(gene_summary, self.max_val)
        return mask
@register_preprocessor("filter", "gene")
@add_mod_and_transform
class FilterGenesTopK(FilterGenes):
    """Select top/bottom genes based on the summarized gene expressions.

    Parameters
    ----------
    num_genes
        Number of genes to be selected. If larger than the total number of genes, all genes are selected.
    top
        If set to :obj:`True`, then use the genes with highest values of the specified gene summary stats;
        otherwise use the genes with the lowest values.
    mode
        Summarization mode. Available options are ``[sum|var|cv|rv]``. ``sum`` calculates the sum of expression
        values, ``var`` calculates the variance of the expression values, ``cv`` uses the coefficient of variation
        (std / mean), and ``rv`` uses the relative variance (var / mean).
    channel
        Which channel, more specifically, ``layers``, to use. Use the default ``.X`` if not set. If ``channel`` is
        specified, then need to specify ``channel_type`` to be ``layers`` as well.
    channel_type
        Type of channels specified. Defaults to ``"X"``; must be ``layers`` when ``channel`` is specified.
    whitelist_indicators
        A list of (or a single) :obj:`.var` columns that indicates the genes to be excluded from the filtering
        process. These genes still take part in the summary stats computation and are kept in addition to the
        selected genes. If not set, then no genes will be excluded from the filtering process.
    add_n_counts
        Whether to add ``n_counts``, the total counts for each gene.
    add_n_cells
        Whether to add ``n_cells``, the number of cells expressed for each gene.
    inplace
        If inplace is True, the original data is replaced with the filtered data.
        If inplace is False, the filtered data is stored in obsm.

    """

    _DISPLAY_ATTRS = ("num_genes", "top", "mode")

    def __init__(
        self,
        num_genes: int = 1000,
        top: bool = True,
        *,
        mode: GeneSummaryMode = "cv",
        channel: Optional[str] = None,
        channel_type: Optional[str] = "X",
        whitelist_indicators: Optional[Union[str, List[str]]] = None,
        add_n_counts=False,
        add_n_cells=False,
        inplace=True,
        **kwargs,
    ):
        super().__init__(
            mode=mode,
            channel=channel,
            channel_type=channel_type,
            whitelist_indicators=whitelist_indicators,
            add_n_counts=add_n_counts,
            add_n_cells=add_n_cells,
            inplace=inplace,
            **kwargs,
        )
        self.num_genes = num_genes
        self.top = top

    def _get_preserve_mask(self, gene_summary):
        """Mark the ``num_genes`` genes with the highest (``top``) or lowest summary values."""
        total_num_genes = gene_summary.size
        # Clip locally so the configured ``num_genes`` is not altered by one call on a smaller dataset.
        num_genes = self.num_genes
        if num_genes > total_num_genes:
            self.logger.warning(f"{self.num_genes=!r} > total number of genes: {total_num_genes}")
            num_genes = total_num_genes
        num_genes = max(num_genes, 0)

        sorted_idx = gene_summary.argsort()
        # Use an explicit start index: ``sorted_idx[-0:]`` would select every gene when num_genes == 0.
        selected_idx = sorted_idx[total_num_genes - num_genes:] if self.top else sorted_idx[:num_genes]
        mask = np.zeros(total_num_genes, dtype=bool)
        mask[selected_idx] = True
        return mask
@register_preprocessor("filter", "gene")
class FilterGenesMarker(BaseTransform):
    """Select marker genes based on log fold-change.

    Parameters
    ----------
    ct_profile_channel
        Name of the ``.varm`` channel that contains the cell-topic profile which will be used to compute the log
        fold-changes for each cell-topic (e.g., cell type).
    subset
        If set to :obj:`True`, then inplace subset the variables to only contain the markers.
    label
        If set, e.g., to :obj:`'marker'`, then save the marker indicator to the :obj:`.var` column named as
        :obj:`marker`.
    threshold
        Threshold value of the log fold-change above which the gene will be considered as a marker.
    eps
        A small value that prevents taking log of zeros.

    Notes
    -----
    The per-gene-by-cell-type marker indicator table is always saved to ``.varm[self.out]``.

    """

    _DISPLAY_ATTRS = ("ct_profile_channel", "subset", "threshold", "eps")

    def __init__(
        self,
        *,
        ct_profile_channel: str = "CellTopicProfile",
        subset: bool = True,
        label: Optional[str] = None,
        threshold: float = 1.25,
        eps: float = 1e-6,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ct_profile_channel = ct_profile_channel
        self.subset = subset
        self.label = label
        self.threshold = threshold
        self.eps = eps

    @staticmethod
    def get_marker_genes(
        ct_profile: np.ndarray,  # gene x celltype
        cell_types: List[str],
        genes: List[str],
        *,
        threshold: float = 1.25,
        eps: float = 1e-6,
        logger: Logger = default_logger,
    ) -> Tuple[List[str], pd.DataFrame]:
        """Find marker genes for each cell type with a one-vs-rest log fold-change.

        For cell type ``i``, ``log(ct_profile[:, i] + eps) - log(mean of the other columns + eps)`` is computed per
        gene, and genes whose value exceeds ``threshold`` are markers of that cell type.

        Parameters
        ----------
        ct_profile
            Profile matrix with one row per gene and one column per cell type.
        cell_types
            Cell type names, in the column order of ``ct_profile``.
        genes
            Gene names, in the row order of ``ct_profile``.
        threshold
            Log fold-change above which a gene is a marker.
        eps
            A small value that prevents taking log of zeros.
        logger
            Logger used for progress messages.

        Returns
        -------
        marker_genes
            Genes that are markers of at least one cell type.
        marker_gene_ind_df
            Boolean gene-by-cell-type marker indicator table.

        Raises
        ------
        ValueError
            If fewer than two cell types are given.

        """
        if (num_cts := len(cell_types)) < 2:
            raise ValueError(f"Need at least two cell types to find marker genes, got {num_cts}:\n{cell_types}")

        # Find marker genes for each cell type
        marker_gene_ind_df = pd.DataFrame(False, index=genes, columns=cell_types)
        for i, ct in enumerate(cell_types):
            others = [j for j in range(num_cts) if j != i]
            log_fc = np.log(ct_profile[:, i] + eps) - np.log(ct_profile[:, others].mean(1) + eps)
            markers_idx = np.where(log_fc > threshold)[0]

            if markers_idx.size > 0:
                marker_gene_ind_df.iloc[markers_idx, i] = True
                markers = marker_gene_ind_df.iloc[markers_idx].index.tolist()
                logger.info(f"Found {len(markers):,} marker genes for cell type {ct!r}")
                logger.debug(f"{markers=}")
            else:
                logger.info(f"No marker genes found for cell type {ct!r}")

        # Combine all marker genes
        is_marker = marker_gene_ind_df.max(1)
        marker_genes = is_marker[is_marker].index.tolist()
        logger.info(f"Total number of marker genes found: {len(marker_genes):,}")
        logger.debug(f"{marker_genes=}")

        return marker_genes, marker_gene_ind_df

    def __call__(self, data):
        ct_profile_df = data.get_feature(channel=self.ct_profile_channel, channel_type="varm", return_type="default")
        ct_profile = ct_profile_df.values
        cell_types = ct_profile_df.columns.tolist()
        genes = ct_profile_df.index.tolist()

        marker_genes, marker_gene_ind_df = self.get_marker_genes(ct_profile, cell_types, genes, eps=self.eps,
                                                                 threshold=self.threshold, logger=self.logger)

        # Save marker gene info to data
        data.data.varm[self.out] = marker_gene_ind_df
        if self.label is not None:
            data.data.var[self.label] = marker_gene_ind_df.max(1)

        if self.subset:  # inplace subset the variables
            data.data._inplace_subset_var(marker_genes)
@register_preprocessor("filter", "gene")
@add_mod_and_transform
class FilterGenesRegression(BaseTransform):
    """Select genes based on regression.

    Parameters
    ----------
    method
        What regression based gene selection method to use. Supported options are:
        ``"enclasc"``, ``"seurat3"``, and ``"scmap"``.
    num_genes
        Number of genes to select. Capped at the total number of genes.
    channel
        Channel of the feature matrix used for the selection.
    channel_type
        Type of ``channel``.
    mod
        Modality argument; not stored by this class (presumably consumed via ``add_mod_and_transform`` --
        TODO confirm).
    skip_count_check
        If True, do not warn when the feature matrix contains non-integer values.
    inplace
        If inplace is True, the original data is replaced with the filtered data.
        If inplace is False, the filtered data is stored in obsm.

    Note
    ----
    The implementation is adapted from the EnClaSC GitHub repo: https://github.com/xy-chen16/EnClaSC

    References
    ----------
    https://bmcbioinformatics.biomedcentral.com/articles/10.1186/s12859-020-03679-z

    """

    _DISPLAY_ATTRS = ("num_genes", )

    def __init__(self, method: str = "enclasc", num_genes: int = 1000, *, channel: Optional[str] = None,
                 channel_type: Optional[str] = None, mod: Optional[str] = None, skip_count_check: bool = False,
                 inplace=True, **kwargs):
        super().__init__(**kwargs)
        self.num_genes = num_genes
        self.channel = channel
        self.method = method
        self.skip_count_check = skip_count_check
        self.inplace = inplace
        self.channel_type = channel_type

    def __call__(self, data):
        feat = data.get_feature(return_type="numpy", channel=self.channel, channel_type=self.channel_type)
        if not self.skip_count_check and np.mod(feat, 1).sum():
            warnings.warn("Expecting count data as input, but the input feature matrix does not appear to be count. "
                          "Please make sure the input is indeed a count matrix.")

        func_dict = {"enclasc": self._filter_enclasc, "seurat3": self._filter_seurat3, "scmap": self._filter_scmap}
        if (filter_func := func_dict.get(self.method)) is None:
            raise ValueError(f"Unknown method {self.method}, supported options are: {list(func_dict)}.")

        # Clip locally so the configured ``num_genes`` is not altered by one call on a smaller dataset.
        num_genes = self.num_genes
        if num_genes > feat.shape[1]:
            self.logger.warning(f"{self.num_genes=!r} > total number of genes: {feat.shape[1]}")
            num_genes = feat.shape[1]

        gene_names = data.data.var_names[filter_func(feat, num_genes, logger=self.logger)]
        if self.inplace:
            data.data._inplace_subset_var(gene_names)
        else:
            data.data.obsm[self.out] = data.data[:, gene_names].X
        return data

    @staticmethod
    def _top_k_indices(scores: np.ndarray, num_genes: int) -> np.ndarray:
        """Return the (unordered) indices of the ``num_genes`` highest scores; empty for ``num_genes <= 0``."""
        if num_genes <= 0:
            # ``argpartition(...)[-0:]`` would otherwise return every index.
            return np.array([], dtype=int)
        return np.argpartition(scores, -num_genes)[-num_genes:]

    def _filter_enclasc(self, feat: np.ndarray, num_genes: int = 2000, logger: Logger = default_logger,
                        no_check: bool = False) -> np.ndarray:
        """Score genes by the EnClaSC regression on dropout rate and return the top ``num_genes`` indices."""
        logger.info("Start generating cell features using EnClaSC")
        num_feat = feat.shape[1]
        # Genes with dropout rate exactly 0 or 1 are not regressed and keep this low placeholder score.
        scores = np.zeros(num_feat) - 100

        feat_mean = feat.mean(0)
        drop_feat = (feat == 0).mean(0)
        select_index = (0 < drop_feat) & (drop_feat < 1)

        x1 = feat_mean[select_index].reshape(-1, 1)
        x2 = drop_feat[select_index].reshape(-1, 1)
        y = np.log(feat_mean + 1)[select_index].reshape(-1, 1)
        y_pred = LinearRegression(n_jobs=8).fit(x2, y).predict(x2)

        scores[select_index] = (2 * y - y_pred - x1).ravel()
        return self._top_k_indices(scores, num_genes)

    def _filter_seurat3(self, feat: np.ndarray, num_genes: int = 2000, logger: Logger = default_logger,
                        no_check: bool = False) -> np.ndarray:
        """Score genes by residual log-variance over a quadratic fit on log-mean; return top ``num_genes``."""
        logger.info("Start generating cell features using Seurat v3.0")
        feat_mean_log = np.log(feat.mean(0) + 1)
        feat_var_log = np.log(feat.var(0) + 1)

        x = PolynomialFeatures(degree=2).fit_transform(feat_mean_log.reshape(-1, 1))
        y_pred = LinearRegression().fit(x, feat_var_log).predict(x)

        scores = (feat_var_log - y_pred).ravel()
        return self._top_k_indices(scores, num_genes)

    def _filter_scmap(self, feat: np.ndarray, num_genes: int = 2000, logger: Logger = default_logger,
                      no_check: bool = False) -> np.ndarray:
        """Score genes by residual log dropout over a linear fit on log-mean (scmap); return top ``num_genes``."""
        logger.info("Start generating cell features using scmap")
        num_feat = feat.shape[1]
        # Genes with dropout rate exactly 0 or 1 are not regressed and keep this low placeholder score.
        scores = np.zeros(num_feat) - 100

        feat_mean = feat.mean(0)
        drop_feat = (feat == 0).mean(0)
        select_index = (0 < drop_feat) & (drop_feat < 1)

        # ``np.log(2.7) / np.log(2)`` approximately converts the natural log to log2.
        x = np.log(feat_mean[select_index] + 1).reshape(-1, 1) * np.log(2.7) / np.log(2)
        y = np.log(drop_feat[select_index] * 100).reshape(-1, 1) * np.log(2.7) / np.log(2)
        y_pred = LinearRegression().fit(x, y).predict(x)

        scores[select_index] = (y - y_pred).ravel()
        return self._top_k_indices(scores, num_genes)
@register_preprocessor("filter", "gene")
class FilterGenesMarkerGini(BaseTransform):
    """Select marker genes based on Gini coefficient.

    Identify marker genes for all clusters in a one vs all manner based on Gini coefficients, a measure for
    inequality.

    Parameters
    ----------
    ct_profile_channel
        Name of the ``.varm`` channel that contains the cell-topic profile which will be used to compute the
        scores for each cell-topic (e.g., cell type).
    ct_profile_detection_channel
        Name of the ``.varm`` channel that contains the cell-topic detection profile (values greater than some
        cutoff) which will be used to compute the scores for each cell-topic (e.g., cell type).
    subset
        If set to :obj:`True`, then inplace subset the variables to only contain the markers.
    label
        If set, e.g., to :obj:`'marker'`, then save the marker indicator to the :obj:`.var` column named as
        :obj:`marker`.

    Notes
    -----
    Per-cell-type cell counts are read from the ``"nums"`` column of the ``CellTypeNums`` ``.uns`` entry, and the
    concatenated per-cell-type score tables are saved to ``.uns[self.out]``.

    References
    ----------
    https://genomebiology.biomedcentral.com/articles/10.1186/s13059-016-1010-4?ref=https://githubhelp.com

    """

    def __init__(
        self,
        *,
        ct_profile_channel: str = "CellGiottoTopicProfile",
        ct_profile_detection_channel: str = "CellGiottoDetectionTopicProfile",
        subset: bool = True,
        label: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ct_profile_channel = ct_profile_channel
        self.ct_profile_detection_channel = ct_profile_detection_channel
        self.subset = subset
        self.label = label

    def __call__(
        self,
        data,
        logger: Optional[Logger] = None,
    ):
        # Fall back to the transform's own logger, consistent with the other transforms in this module.
        logger = self.logger if logger is None else logger

        ct_profile_df = data.get_feature(channel=self.ct_profile_channel, channel_type="varm", return_type="default")
        ct_profile_detection_df = data.get_feature(channel=self.ct_profile_detection_channel, channel_type="varm",
                                                   return_type="default")
        cell_type_nums_df = data.get_feature(channel="CellTypeNums", channel_type="uns", return_type="default")
        ct_profile = ct_profile_df.values
        ct_profile_detection = ct_profile_detection_df.values
        cell_types = ct_profile_df.columns.tolist()
        genes = ct_profile_df.index.tolist()

        if (num_cts := len(cell_types)) < 2:
            raise ValueError(f"Need at least two cell types to find marker genes, got {num_cts}:\n{cell_types}")

        marker_gene_ind_df = pd.DataFrame(False, index=genes, columns=cell_types)
        ans_gene = []
        for i, ct in enumerate(cell_types):
            # Cell-count weighted average of all other cell types' expression and detection profiles.
            other_ct_profile = np.zeros_like(ct_profile[:, i])
            other_detection_ct_profile = np.zeros_like(ct_profile[:, i])
            other_sum = 0
            for j in range(num_cts):
                if j == i:
                    continue
                other = cell_type_nums_df.loc[cell_types[j], "nums"]
                other_ct_profile += ct_profile[:, j] * other
                other_detection_ct_profile += ct_profile_detection[:, j] * other
                other_sum += other
            other_ct_profile = other_ct_profile / other_sum
            other_detection_ct_profile = other_detection_ct_profile / other_sum

            top_genes_scores_filtered = get_marker_genes_giotto(ct_profile[:, i], other_ct_profile,
                                                                ct_profile_detection[:, i], other_detection_ct_profile,
                                                                genes=genes)
            markers_idx = np.array(top_genes_scores_filtered.index)
            top_genes_scores_filtered["cellType"] = ct
            ans_gene.append(top_genes_scores_filtered)

            if markers_idx.size > 0:
                marker_gene_ind_df.iloc[markers_idx, i] = True
                markers = marker_gene_ind_df.iloc[markers_idx].index.tolist()
                logger.info(f"Found {len(markers):,} marker genes for cell type {ct!r}")
                logger.debug(f"{markers=}")
            else:
                logger.info(f"No marker genes found for cell type {ct!r}")

        # Combine all marker genes
        is_marker = marker_gene_ind_df.max(1)
        marker_genes = is_marker[is_marker].index.tolist()

        # Save marker gene info to data
        data.data.uns[self.out] = pd.concat(ans_gene, axis=0)
        if self.label is not None:
            # ``pd.concat`` on a single Series raises TypeError; the per-gene indicator is assigned directly.
            data.data.var[self.label] = marker_gene_ind_df.max(1)

        if self.subset:  # inplace subset the variables
            data.data._inplace_subset_var(marker_genes)
def get_marker_genes_giotto(group1, group2, group_detection_1, group_detection_2, min_expr_gini_score=0.2,
                            min_det_gini_score=0.2, rank_score=1, min_genes=5, genes=None):
    """Score genes of one group against a reference group and select marker genes (Giotto-style Gini markers).

    Parameters
    ----------
    group1
        1-D per-gene expression profile of the group of interest.
    group2
        1-D per-gene expression profile of the reference group (same length as ``group1``).
    group_detection_1
        1-D per-gene detection profile of the group of interest.
    group_detection_2
        1-D per-gene detection profile of the reference group.
    min_expr_gini_score
        Threshold compared against the group-of-interest expression value (column ``"expression"``) in the
        second filtering step.
    min_det_gini_score
        Threshold compared against the group-of-interest detection value (column ``"detection"``) in the second
        filtering step.
    rank_score
        Upper bound on both rescaled rank scores in the first filtering step. Rank scores lie in ``[0.1, 1]``,
        so the default of 1 never excludes a gene with finite scores.
    min_genes
        Genes whose combined rank (``ans_rank``) is at most this value are always kept.
    genes
        Optional gene names, stored in the ``"gene_name"`` column.

    Returns
    -------
    pd.DataFrame
        The selected genes, indexed by their integer position in the inputs, with columns ``ans_score``,
        ``ans_rank``, ``expression``, ``detection``, ``expression_gini``, ``detection_gini``, ``gene_rank_score``,
        ``gene_detection_rank_score`` and ``gene_name``. A fresh copy is returned so callers may add columns.

    Notes
    -----
    NOTE(review): the ``*_gini_score`` thresholds are applied to the raw expression/detection values rather than
    to the Gini columns; this appears to follow Giotto's ``findGiniMarkers`` — confirm before changing.

    """
    # Row 0 is the group of interest, row 1 the reference group; one column per gene.
    expressions = np.vstack([group1, group2]).astype(np.float64)
    detections = np.vstack([group_detection_1, group_detection_2]).astype(np.float64)
    gene_nums = expressions.shape[1]

    # Gini coefficient of each gene's (group, reference) pair.
    gene_gini_score = np.array([gini_func(pair) for pair in expressions.T], dtype=np.float64)
    gene_detection_gini_score = np.array([gini_func(pair) for pair in detections.T], dtype=np.float64)

    # Rank the two groups within each gene (higher value -> higher rank), then rescale each gene's ranks to
    # [0.1, 1]: the higher group gets 1, the lower 0.1, and ties collapse to 0.1.
    scaler = MinMaxScaler(feature_range=(0.1, 1))
    gene_rank_score = scaler.fit_transform(rankdata(expressions, axis=0))
    gene_detection_rank_score = scaler.fit_transform(rankdata(detections, axis=0))

    # Combined score of the group of interest; rank 1 is the highest score.
    ans_score = gene_detection_gini_score * gene_gini_score * gene_rank_score[0] * gene_detection_rank_score[0]
    ans_rank = np.argsort(np.argsort(-ans_score)) + 1

    # Build the frame from typed columns (filling an all-bool frame via ``.loc`` relies on deprecated upcasting).
    ans_df = pd.DataFrame(
        {
            "ans_score": ans_score,
            "ans_rank": ans_rank,
            "expression": expressions[0],
            "detection": detections[0],
            "expression_gini": gene_gini_score,
            "detection_gini": gene_detection_gini_score,
            "gene_rank_score": gene_rank_score[0],
            "gene_detection_rank_score": gene_detection_rank_score[0],
            "gene_name": genes,
        },
        index=range(gene_nums),
    )

    # Filter on combined rank or individual ranks
    keep_by_rank = (ans_df["ans_rank"] <= min_genes) | ((ans_df["gene_rank_score"] <= rank_score) &
                                                         (ans_df["gene_detection_rank_score"] <= rank_score))
    top_genes_scores = ans_df[keep_by_rank]

    # Further filter on expression and detection values
    keep_by_value = (top_genes_scores["ans_rank"] <= min_genes) | (
        (top_genes_scores["expression"] > min_expr_gini_score) & (top_genes_scores["detection"] > min_det_gini_score))
    # Copy so that callers adding columns do not trigger SettingWithCopyWarning.
    return top_genes_scores[keep_by_value].copy()


def gini_func(x, weights=None):
    """Compute the weighted Gini coefficient of ``x``.

    The Gini numerator is normalized by ``(T_pos + |T_neg|) / N``, where ``N`` is the total weight, ``T_neg`` the
    sum of non-positive weighted values and ``T_pos = sum(x * w) + |T_neg|``.

    Parameters
    ----------
    x
        1-D sequence of values.
    weights
        Optional per-value weights; defaults to all ones.

    Returns
    -------
    float
        The coefficient; ``0`` for equal values. When all values are zero the denominator is zero and numpy
        yields ``nan`` (with a RuntimeWarning).

    """
    if weights is None:
        weights = np.ones(len(x))
    # Sort values, together with their weights, in ascending order of x.
    dataset = np.column_stack((x, weights))
    dataset_ord = dataset[np.argsort(x)]
    x = dataset_ord[:, 0]
    weights = dataset_ord[:, 1]

    N = np.sum(weights)
    xw = x * weights
    C_i = np.cumsum(weights)
    num_1 = np.sum(xw * C_i)
    num_2 = np.sum(xw)
    num_3 = np.sum(xw * weights)
    G_num = (2 / N**2) * num_1 - (1 / N) * num_2 - (1 / N**2) * num_3

    # Normalization term that also accounts for non-positive weighted values.
    T_neg = np.sum(xw[xw <= 0])
    T_pos = np.sum(xw) + np.abs(T_neg)
    n_RSV = 2 * (T_pos + np.abs(T_neg)) / N
    mean_RSV = n_RSV / 2
    return G_num / mean_RSV
@register_preprocessor("filter", "gene")
@add_mod_and_transform
class FilterGenesScanpyOrder(BaseTransform):
    """Scanpy filtering gene transformation with additional options.

    Each configured criterion is applied as a separate :class:`FilterGenesScanpy` step, in the given order.

    Parameters
    ----------
    order
        Order of (min_counts, min_cells, max_counts, max_cells). For example,
        ``["min_counts", "min_cells", "max_counts", "max_cells"]`` or ``["max_counts", "min_cells"]``. If not set,
        will be set by default to ``["min_counts", "min_cells", "max_counts", "max_cells"]``. Criteria left out of
        the order are not applied (a warning is logged if such a criterion was given a value).
    min_counts
        Minimum number of counts required for a gene to be kept.
    min_cells
        Minimum number (or ratio) of cells required for a gene to be kept.
    max_counts
        Maximum number of counts allowed for a gene to be kept.
    max_cells
        Maximum number (or ratio) of cells allowed for a gene to be kept.
    split_name
        Which split to be used for filtering.
    channel
        Channel to be used for filtering.
    channel_type
        Channel type to be used for filtering.
    add_n_counts
        Whether to add ``n_counts``, the total counts for each gene.
    add_n_cells
        Whether to add ``n_cells``, the number of cells expressed for each gene.
    inplace
        If inplace is True, the original data is replaced with the filtered data.
        If inplace is False, the filtered data is stored in obsm
    params_dict
        Unused; accepted for backward compatibility.

    """

    def __init__(
        self,
        order: Optional[List[str]] = None,
        min_counts: Optional[Union[float, int]] = None,
        min_cells: Optional[Union[float, int]] = None,
        max_counts: Optional[Union[float, int]] = None,
        max_cells: Optional[Union[float, int]] = None,
        split_name: Optional[str] = None,
        channel: Optional[str] = None,
        channel_type: Optional[str] = "X",
        add_n_counts=True,
        add_n_cells=True,
        inplace=True,
        params_dict=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.filter_genes_order = default(
            order,
            ["min_counts", "min_cells", "max_counts", "max_cells"],
        )
        self.logger.info(f"Filter genes order: {self.filter_genes_order}")
        geneParameterDict = {
            "min_counts": min_counts,
            "min_cells": min_cells,
            "max_counts": max_counts,
            "max_cells": max_cells
        }
        self.add_n_counts = add_n_counts
        self.add_n_cells = add_n_cells
        if not set(self.filter_genes_order).issubset(set(geneParameterDict.keys())):
            raise KeyError(f"An order should be in {geneParameterDict.keys()}")

        key_n_counts = ("n_counts" if self.add_n_counts else None)
        key_n_cells = ("n_cells" if self.add_n_cells else None)
        self.geneScanpyOrderDict = {}
        for key, value in geneParameterDict.items():
            if key in self.filter_genes_order:
                self.geneScanpyOrderDict[key] = FilterGenesScanpy(
                    **{key: value},
                    split_name=split_name,
                    channel=channel,
                    channel_type=channel_type,
                    key_n_counts=key_n_counts,
                    key_n_cells=key_n_cells,
                    inplace=inplace,
                    **kwargs,
                )
            elif value is not None:
                # The criterion was explicitly set but will never be applied.
                self.logger.warning(f"{key} not in order,It makes no sense to set {key}")

    def __call__(self, data: Data):
        for parameter in self.filter_genes_order:
            geneScanpyOrder = self.geneScanpyOrderDict[parameter]
            geneScanpyOrder(data)
@register_preprocessor("filter", "gene")
@add_mod_and_transform
class HighlyVariableGenesRawCount(AnnDataTransform):
    """Filter for highly variable genes using raw count matrix.

    Wraps :func:`scanpy.pp.highly_variable_genes` with ``flavor="seurat_v3"``, which expects count data.

    Parameters
    ----------
    channel
        Name of the layer to read expression values from; only used when ``channel_type`` is ``"layers"``.
        Otherwise ``data.data.X`` is used.
    channel_type
        Channel type; ``"layers"`` selects ``data.data.layers[channel]``.
    n_top_genes
        Number of highly-variable genes to keep.
    span
        The fraction of the data (cells) used when estimating the variance in the loess model fit.
    subset
        Inplace subset to highly-variable genes if `True` otherwise merely indicate highly variable genes.
    inplace
        Whether to place calculated metrics in `.var` or return them.
    batch_key
        If specified, highly-variable genes are selected within each batch separately and merged. Genes are first
        sorted by how many batches they are a HVG; ties are broken by the median (across batches) rank based on
        within-batch normalized variance.
    check_values
        Check if counts in selected layer are integers. A Warning is returned if set to True.

    Raises
    ------
    ValueError
        When called on data whose ``.X`` has no genes.

    See also
    --------
    This is a wrapper for
    https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.highly_variable_genes.html

    """

    def __init__(self, channel: Optional[str] = None, channel_type: Optional[str] = None,
                 n_top_genes: Optional[int] = 1000, span: Optional[float] = 0.3, subset: bool = True,
                 inplace: bool = True, batch_key: Optional[str] = None, check_values: bool = True, **kwargs):
        layer_name = None
        if channel_type == "layers":
            layer_name = channel
        super().__init__(
            sc.pp.highly_variable_genes,
            layer=layer_name,
            n_top_genes=n_top_genes,
            batch_key=batch_key,
            check_values=check_values,
            span=span,
            subset=subset,
            inplace=inplace,
            flavor="seurat_v3",
            **kwargs,
        )
        self.logger.info("Expects count data")

    def __call__(self, data):
        num_genes = data.data.X.shape[1]
        if not num_genes:
            # Guard against an empty gene dimension, which otherwise crashes the kernel.
            raise ValueError("Gene dimension is 0")
        return super().__call__(data)
# def is_integer_sample(self, arr, sample_ratio=0.01): # """Check if the data is an integer.""" # if isinstance(arr, np.ndarray): # # numpy array # positive_indices = np.where(arr > 0) # data = arr[positive_indices] # elif isinstance(arr, (csr_matrix, csc_matrix)): # nonzero_indices = arr.nonzero() # data = arr[nonzero_indices].data # else: # raise TypeError(f"Input must be either a numpy array or a csr_matrix,should not be {type(arr)}") # sample_size = int(len(data) * sample_ratio) # if sample_size == 0: # is_integer = np.all(np.equal(np.mod(data, 1), 0)) # else: # sample_indices = np.random.choice(len(data), size=sample_size, replace=False) # sample = data[sample_indices] # is_integer = np.all(np.equal(np.mod(sample, 1), 0)) # return is_integer
@register_preprocessor("filter", "gene")
@add_mod_and_transform
class HighlyVariableGenesLogarithmizedByTopGenes(AnnDataTransform):
    """Filter for highly variable genes based on top genes.

    Wraps :func:`scanpy.pp.highly_variable_genes` with a fixed ``n_top_genes``; expects logarithmized data.

    Parameters
    ----------
    channel
        Name of the layer to read expression values from; only used when ``channel_type`` is ``"layers"``.
        Otherwise ``data.data.X`` is used.
    channel_type
        Channel type; ``"layers"`` selects ``data.data.layers[channel]``.
    n_top_genes
        Number of highly-variable genes to keep.
    n_bins
        Number of bins for binning the mean gene expression. Normalization is done with respect to each bin.
        If just a single gene falls into a bin, the normalized dispersion is artificially set to 1.
    flavor
        Choose the flavor for identifying highly variable genes (``"seurat"`` or ``"cell_ranger"``).
    subset
        Inplace subset to highly-variable genes if `True` otherwise merely indicate highly variable genes.
    inplace
        Whether to place calculated metrics in `.var` or return them.
    batch_key
        If specified, highly-variable genes are selected within each batch separately and merged. Genes are first
        sorted by how many batches they are a HVG; for dispersion-based flavors ties are broken by normalized
        dispersion.

    See also
    --------
    This is a wrapper for
    https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.highly_variable_genes.html

    """

    def __init__(self, channel: Optional[str] = None, channel_type: Optional[str] = None,
                 n_top_genes: Optional[int] = 1000, n_bins: int = 20,
                 flavor: Literal["seurat", "cell_ranger"] = "seurat", subset: bool = True, inplace: bool = True,
                 batch_key: Optional[str] = None, **kwargs):
        layer_name = None
        if channel_type == "layers":
            layer_name = channel
        super().__init__(
            sc.pp.highly_variable_genes,
            layer=layer_name,
            n_top_genes=n_top_genes,
            n_bins=n_bins,
            flavor=flavor,
            subset=subset,
            inplace=inplace,
            batch_key=batch_key,
            **kwargs,
        )
        self.logger.info("Expects logarithmized data")
@register_preprocessor("filter", "gene")
@add_mod_and_transform
# @deprecated(msg="will be replaced by builtin bypass mechanism in pipeline")
class FilterGenesPlaceHolder(BaseTransform):
    """Used as a placeholder to skip the process.

    No genes are removed. Optionally records per-gene statistics in ``.var``: ``n_counts`` (sum over cells) and
    ``n_cells`` (number of cells with a positive value). When ``inplace`` is False, the feature matrix is also
    stored in ``.obsm[self.out]``.

    """

    def __init__(self, split_name: Optional[str] = None, channel: Optional[str] = None,
                 channel_type: Optional[str] = "X", add_n_counts=True, add_n_cells=True, inplace=True, **kwargs):
        super().__init__(**kwargs)
        self.split_name, self.channel, self.channel_type = split_name, channel, channel_type
        self.add_n_counts, self.add_n_cells = add_n_counts, add_n_cells
        self.inplace = inplace

    def __call__(self, data: Data) -> Data:
        matrix = data.get_feature(return_type="numpy", split_name=self.split_name, channel=self.channel,
                                  channel_type=self.channel_type)
        # Column-wise (per-gene) statistics.
        per_gene_total = np.sum(matrix, axis=0)
        per_gene_detected = np.sum(matrix > 0, axis=0)

        if self.add_n_counts:
            self.logger.warning("n_counts will be added to the var of data")
            data.data.var["n_counts"] = per_gene_total
        if self.add_n_cells:
            self.logger.warning("n_cells will be added to the var of data")
            data.data.var["n_cells"] = per_gene_detected
        if not self.inplace:
            data.data.obsm[self.out] = matrix
        return data
@register_preprocessor("filter", "gene")
@add_mod_and_transform
# @deprecated(msg="will be replaced by builtin bypass mechanism in pipeline")
class FilterGenesNumberPlaceHolder(BaseTransform):
    """No-op placeholder for a gene-number filtering step; the data is returned unchanged."""

    def __init__(self, channel=None, channel_type=None, **kwargs):
        # ``channel`` and ``channel_type`` are accepted but not used.
        super().__init__(**kwargs)

    def __call__(self, data: Data) -> Data:
        # Intentionally does nothing.
        return data
@register_preprocessor("filter", "gene")
@add_mod_and_transform
class HighlyVariableGenesLogarithmizedByMeanAndDisp(AnnDataTransform):
    """Filter for highly variable genes based on mean and dispersion.

    Wraps :func:`scanpy.pp.highly_variable_genes` using mean/dispersion cutoffs; expects logarithmized data.

    Parameters
    ----------
    channel
        Name of the layer to read expression values from; only used when ``channel_type`` is ``"layers"``.
        Otherwise ``data.data.X`` is used.
    channel_type
        Channel type; ``"layers"`` selects ``data.data.layers[channel]``.
    min_disp
        Lower cutoff on the normalized dispersion.
    max_disp
        Upper cutoff on the normalized dispersion.
    min_mean
        Lower cutoff on the mean expression.
    max_mean
        Upper cutoff on the mean expression.
    n_bins
        Number of bins for binning the mean gene expression. Normalization is done with respect to each bin.
        If just a single gene falls into a bin, the normalized dispersion is artificially set to 1.
    flavor
        Choose the flavor for identifying highly variable genes (``"seurat"`` or ``"cell_ranger"``).
    subset
        Inplace subset to highly-variable genes if `True` otherwise merely indicate highly variable genes.
    inplace
        Whether to place calculated metrics in `.var` or return them.
    batch_key
        If specified, highly-variable genes are selected within each batch separately and merged. Genes are first
        sorted by how many batches they are a HVG; for dispersion-based flavors ties are broken by normalized
        dispersion.

    See also
    --------
    This is a wrapper for
    https://scanpy.readthedocs.io/en/stable/generated/scanpy.pp.highly_variable_genes.html

    """

    def __init__(self, channel: Optional[str] = None, channel_type: Optional[str] = None,
                 min_disp: Optional[float] = 0.5, max_disp: Optional[float] = np.inf,
                 min_mean: Optional[float] = 0.0125, max_mean: Optional[float] = 3, n_bins: int = 20,
                 flavor: Literal["seurat", "cell_ranger"] = "seurat", subset: bool = True, inplace: bool = True,
                 batch_key: Optional[str] = None, **kwargs):
        layer_name = None
        if channel_type == "layers":
            layer_name = channel
        super().__init__(
            sc.pp.highly_variable_genes,
            layer=layer_name,
            min_disp=min_disp,
            max_disp=max_disp,
            min_mean=min_mean,
            max_mean=max_mean,
            n_bins=n_bins,
            flavor=flavor,
            subset=subset,
            inplace=inplace,
            batch_key=batch_key,
            **kwargs,
        )
        self.logger.info("Expects logarithmized data")
@register_preprocessor("filter", "cell")
@add_mod_and_transform
# @deprecated(msg="will be replaced by builtin bypass mechanism in pipeline")
class FilterCellsPlaceHolder(BaseTransform):
    """Used as a placeholder to skip the process.

    No cells are removed. Optionally records per-cell statistics in ``.obs``: ``n_counts`` (sum over genes) and
    ``n_genes`` (number of genes with a positive value). When ``inplace`` is False, the transposed feature matrix
    is also stored in ``.varm[self.out]``.

    """

    def __init__(self, split_name: Optional[str] = None, channel: Optional[str] = None,
                 channel_type: Optional[str] = "X", add_n_counts=True, add_n_genes=True, inplace=True, **kwargs):
        super().__init__(**kwargs)
        self.split_name, self.channel, self.channel_type = split_name, channel, channel_type
        self.add_n_counts, self.add_n_genes = add_n_counts, add_n_genes
        self.inplace = inplace

    def __call__(self, data: Data) -> Data:
        matrix = data.get_feature(return_type="numpy", split_name=self.split_name, channel=self.channel,
                                  channel_type=self.channel_type)
        # Row-wise (per-cell) statistics.
        per_cell_total = np.sum(matrix, axis=1)
        per_cell_detected = np.sum(matrix > 0, axis=1)

        if self.add_n_counts:
            self.logger.warning("n_counts will be added to the obs of data")
            data.data.obs["n_counts"] = per_cell_total
        if self.add_n_genes:
            self.logger.warning("n_genes will be added to the obs of data")
            data.data.obs["n_genes"] = per_cell_detected
        if not self.inplace:
            data.data.varm[self.out] = matrix.T
        return data
@register_preprocessor("filter", "cell")
@add_mod_and_transform
class FilterCellsScanpyOrder(BaseTransform):
    """Scanpy filtering cell transformation with additional options.

    Allow passing gene counts as ratio. Each configured criterion is applied as a separate
    :class:`FilterCellsScanpy` step, in the given order.

    Parameters
    ----------
    order
        Order of (min_counts, min_genes, max_counts, max_genes). For example,
        ``["min_counts", "min_genes", "max_counts", "max_genes"]`` or ``["max_counts", "min_genes"]``. If not set,
        will be set by default to ``["min_counts", "min_genes", "max_counts", "max_genes"]``. Criteria left out of
        the order are not applied (a warning is logged if such a criterion was given a value).
    min_counts
        Minimum number of counts required for a cell to be kept.
    min_genes
        Minimum number (or ratio) of genes required for a cell to be kept.
    max_counts
        Maximum number of counts allowed for a cell to be kept.
    max_genes
        Maximum number (or ratio) of genes allowed for a cell to be kept.
    split_name
        Which split to be used for filtering.
    channel
        Channel to be used for filtering.
    channel_type
        Channel type to be used for filtering.
    add_n_counts
        Whether to add ``n_counts``, the total counts for each cell.
    add_n_genes
        Whether to add ``n_genes``, the number of genes expressed for each cell.
    inplace
        If inplace is True, the original data is replaced with the filtered data.
        If inplace is False, the filtered data is stored in varm

    """

    def __init__(self, order: Optional[List[str]] = None, min_counts: Optional[Union[float, int]] = None,
                 min_genes: Optional[Union[float, int]] = None, max_counts: Optional[Union[float, int]] = None,
                 max_genes: Optional[Union[float, int]] = None, split_name: Optional[str] = None,
                 channel: Optional[str] = None, channel_type: Optional[str] = "X", add_n_counts=True,
                 add_n_genes=True, inplace=True, **kwargs):
        super().__init__(**kwargs)
        self.filter_cells_order = default(order, ["min_counts", "min_genes", "max_counts", "max_genes"])
        self.logger.info(f"Filter cells order: {self.filter_cells_order}")
        cellParameterDict = {
            "min_counts": min_counts,
            "min_genes": min_genes,
            "max_counts": max_counts,
            "max_genes": max_genes
        }
        self.add_n_counts = add_n_counts
        self.add_n_genes = add_n_genes
        if not set(self.filter_cells_order).issubset(set(cellParameterDict.keys())):
            raise KeyError(f"An order should be in {cellParameterDict.keys()}")

        key_n_counts = ("n_counts" if self.add_n_counts else None)
        key_n_genes = ("n_genes" if self.add_n_genes else None)
        self.cellScanpyOrderDict = {}
        for key, value in cellParameterDict.items():
            if key in self.filter_cells_order:
                self.cellScanpyOrderDict[key] = FilterCellsScanpy(**{key: value}, split_name=split_name,
                                                                  channel=channel, channel_type=channel_type,
                                                                  key_n_counts=key_n_counts, key_n_genes=key_n_genes,
                                                                  inplace=inplace, **kwargs)
            elif value is not None:
                # Only warn when the criterion was explicitly set but will never be applied.
                self.logger.warning(f"{key} not in order,It makes no sense to set {key}")

    def __call__(self, data: Data):
        for parameter in self.filter_cells_order:
            cellScanpyOrder = self.cellScanpyOrderDict[parameter]
            cellScanpyOrder(data)
@register_preprocessor("filter", "cell")
@add_mod_and_transform
class FilterCellsType(BaseTransform):  # TODO: not in search
    """Filter cell types based on the threshold.

    Cells belonging to any cell type with at most ``cell_type_threshold`` cells are removed. Cell types are read
    from ``.obsm["cell_type"]``, which must be a one-hot encoded :class:`pandas.DataFrame` (one column per cell
    type).

    Parameters
    ----------
    cell_type_threshold
        Cell types whose cell count is less than or equal to this value are removed.

    Raises
    ------
    TypeError
        If ``.obsm["cell_type"]`` is not a :class:`pandas.DataFrame`.

    """

    def __init__(self, cell_type_threshold=10, **kwargs):
        super().__init__(**kwargs)
        self.cell_type_threshold = cell_type_threshold

    def __call__(self, data: Data) -> Data:
        adata = data.data
        one_hot_cell_types_df = adata.obsm["cell_type"]
        if not isinstance(one_hot_cell_types_df, pd.DataFrame):
            raise TypeError(
                f"Expected obsm['cell_type'] to be a pandas.DataFrame, but got {type(one_hot_cell_types_df)}")

        # Number of cells per cell type (column sums of the one-hot matrix).
        cellType_Counts = one_hot_cell_types_df.sum(axis=0)
        celltype_names_to_remove = cellType_Counts[cellType_Counts <= self.cell_type_threshold].index
        self.logger.info(f"Found {len(celltype_names_to_remove)} cell types with counts <= {self.cell_type_threshold}.")

        if celltype_names_to_remove.empty:
            self.logger.info("No cell types below threshold. Keeping all cells.")
            keep_mask = pd.Series(True, index=adata.obs_names)
        else:
            self.logger.info(f"Cell types to remove: {celltype_names_to_remove.tolist()}")
            # A cell is removed if it is flagged in any of the under-represented cell type columns.
            is_cell_to_remove = one_hot_cell_types_df[celltype_names_to_remove].sum(axis=1) > 0
            keep_mask = ~is_cell_to_remove
            self.logger.info(f"Keeping {keep_mask.sum()} cells out of {adata.n_obs}.")

        data.filter_by_mask(keep_mask)
        return data
@register_preprocessor("filter", "cell")
@add_mod_and_transform
class FilterCellTransform(BaseTransform):
    """Remove low-quality cells using MAD-based outlier detection on scanpy QC metrics.

    Adds boolean ``.var`` columns ``mt``, ``ribo`` and ``hb`` (mitochondrial, ribosomal and hemoglobin genes),
    runs :func:`scanpy.pp.calculate_qc_metrics` in place, and removes cells that are either

    * ``outlier``: ``log1p_total_counts``, ``log1p_n_genes_by_counts`` or ``pct_counts_in_top_20_genes`` lies
      more than 5 MADs from its median, or
    * ``mt_outlier``: ``pct_counts_mt`` lies more than 3 MADs from its median or exceeds 8.

    Parameters
    ----------
    species
        Gene-symbol convention used to flag QC genes. ``"human"`` uses ``MT-``, ``RPS``/``RPL`` and ``HB``
        prefixes; any other value uses the mouse convention ``mt-``, ``Rps``/``Rpl`` and ``Hb``.
    image_save_path
        If given, it is set as scanpy's figure directory and QC plots are saved there as PNG files.

    """

    def __init__(self, species: Literal["human", "mouse"] = "human", image_save_path: Optional[str] = None,
                 **kwargs):
        super().__init__(**kwargs)
        # scanpy rejects ``None`` as a figure directory, so only set it when a path is given.
        if image_save_path is not None:
            sc.settings.figdir = image_save_path
        self.species = species
        sc.settings.file_format_figs = "png"
        self.image_save_path = image_save_path

    def is_outlier(self, adata, metric: str, nmads: int):
        """Return a boolean mask of cells whose ``adata.obs[metric]`` is more than ``nmads`` MADs from the median."""
        values = adata.obs[metric]
        median = np.median(values)
        mad = median_abs_deviation(values)
        floor = median - nmads * mad
        ceiling = median + nmads * mad
        outlier = (values < floor) | (ceiling < values)
        self.logger.info(f"metric:{metric} floor:{floor} cell:{ceiling}")
        return outlier

    def __call__(self, data: Data) -> Data:
        adata = data.data
        is_human = self.species == "human"
        # Mitochondrial genes: "MT-" for human, "mt-" for mouse (MGI nomenclature).
        adata.var["mt"] = adata.var_names.str.startswith("MT-" if is_human else "mt-")
        # Ribosomal genes.
        adata.var["ribo"] = adata.var_names.str.startswith(("RPS", "RPL") if is_human else ("Rps", "Rpl"))
        # Hemoglobin genes (the character class excludes HBP / Hbp).
        adata.var["hb"] = adata.var_names.str.contains("^HB[^(P)]" if is_human else "^Hb[^(p)]")
        sc.pp.calculate_qc_metrics(adata, qc_vars=["mt", "ribo", "hb"], inplace=True, percent_top=[20], log1p=True)

        if self.image_save_path is not None:
            sc.pl.violin(adata, ["n_genes_by_counts", "total_counts", "pct_counts_mt"], jitter=0.4, multi_panel=True,
                         show=False, save=True)
            sc.pl.scatter(adata, "total_counts", "n_genes_by_counts", color="pct_counts_mt", show=False, save=True)

        adata.obs["outlier"] = (self.is_outlier(adata, "log1p_total_counts", 5)
                                | self.is_outlier(adata, "log1p_n_genes_by_counts", 5)
                                | self.is_outlier(adata, "pct_counts_in_top_20_genes", 5))
        adata.obs["mt_outlier"] = self.is_outlier(adata, "pct_counts_mt", 3) | (adata.obs["pct_counts_mt"] > 8)
        self.logger.info(f"Total number of cells: {adata.n_obs}")

        mask = (~adata.obs['outlier']) & (~adata.obs['mt_outlier'])
        data.filter_by_mask(mask)
        adata = data.data
        self.logger.info(f"Number of cells after filtering of low quality cells: {adata.n_obs}")
        return data
@register_preprocessor("filter", "cell")
@add_mod_and_transform
class ScrubletTransform(BaseTransform):
    """Remove predicted doublets using :func:`scanpy.pp.scrublet`.

    Cells flagged in ``.obs["predicted_doublet"]`` are filtered out. If ``image_save_path`` is given, it becomes
    scanpy's figure directory and the Scrublet score distribution is saved there as a PNG file.

    """

    def __init__(self, image_save_path: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        if image_save_path is not None:
            sc.settings.figdir = image_save_path
        self.image_save_path = image_save_path
        sc.settings.file_format_figs = "png"

    def __call__(self, data: Data) -> Data:
        sc.pp.scrublet(data.data)
        if self.image_save_path is not None:
            sc.pl.scrublet_score_distribution(data.data, show=False, save=True)
        self.logger.info(f"Original number of cells: {data.data.n_obs}")

        is_singlet = ~data.data.obs['predicted_doublet']
        data.filter_by_mask(is_singlet)
        self.logger.info(f"Number of cells after filtering: {data.data.n_obs}")
        return data