Source code for dance.transforms.scn_feature

import itertools
from collections import defaultdict

import numpy as np
import pandas as pd
import scanpy as sc
from sklearn.metrics import r2_score

from dance import logger
from dance.registry import register_preprocessor
from dance.transforms.base import BaseTransform
from dance.transforms.stats import genestats_alpha, genestats_mu
from dance.typing import Dict, List, Optional, Tuple


@register_preprocessor("feature", "cell")
class SCNFeature(BaseTransform):
    """Differential gene-pair feature used in SingleCellNet.

    On the cells of ``split_name``, differentially expressed genes are selected per cell type
    (:func:`get_diff_exp_genes`), and pairs of those genes are ranked by how well the binarized comparison
    ``gene1 > gene2`` separates each cell type (:func:`get_top_gene_pairs`). The selected pairs are then applied to
    all cells (:func:`query_transform`) and the resulting binary matrix is stored in ``data.data.obsm[self.out]``.

    Parameters
    ----------
    num_top_genes
        Number of top differentially expressed genes selected per cell type.
    alpha1
        ``alpha1`` gene-filtering threshold forwarded to :func:`get_diff_exp_genes`.
    alpha2
        ``alpha2`` gene-filtering threshold forwarded to :func:`get_diff_exp_genes`.
    mu
        ``mu`` gene-filtering threshold forwarded to :func:`get_diff_exp_genes`.
    num_top_gene_pairs
        Number of gene pairs selected per cell type.
    max_gene_per_ct
        Maximum number of selected pairs (of a single cell type) that any one gene may appear in.
    split_name
        Split whose cells are used to select genes and gene pairs.
    channel
        Layer of ``data.data`` holding the expression values (``None`` uses ``.X``).
    channel_type
        Stored on the instance; not used by :meth:`__call__`.

    """

    _DISPLAY_ATTRS = ("num_top_genes", "alpha1", "alpha2", "mu", "num_top_gene_pairs", "max_gene_per_ct",
                      "split_name")

    def __init__(self, num_top_genes: int = 10, alpha1: float = 0.05, alpha2: float = 0.001, mu: float = 2,
                 num_top_gene_pairs: int = 25, max_gene_per_ct: int = 3, *, split_name: Optional[str] = "train",
                 channel: Optional[str] = None, channel_type: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)

        self.num_top_genes = num_top_genes
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.mu = mu
        self.num_top_gene_pairs = num_top_gene_pairs
        self.max_gene_per_ct = max_gene_per_ct
        self.split_name = split_name
        self.channel = channel
        self.channel_type = channel_type

    def __call__(self, data):
        split_idx = data.get_split_idx(self.split_name)

        # TODO: return as numpy or sparse csr to improve efficiency? The gene names (columns) are needed to build
        # the gene pairs, so dimensionality reduction is not straightforward here.
        all_exp_df = data.data.to_df(self.channel)
        cell_type_df = data.get_feature(return_type="default", channel="cell_type",
                                        channel_type="obsm").iloc[split_idx]

        # Select genes / gene pairs on the same expression values (same channel) that are binarized below, so the
        # selected pairs are evaluated on exactly the data they were chosen on. Slicing the already-extracted frame
        # also avoids copying the whole AnnData object.
        # NOTE: gene filtering / highly variable gene selection / scaling is intentionally not performed here.
        norm_exp_df = all_exp_df.iloc[split_idx]

        # Get differentially expressed genes and gene pairs
        cell_type_array = cell_type_df.columns.values[cell_type_df.values.argmax(1)]
        degs_dict = get_diff_exp_genes(norm_exp_df, cell_type_array, alpha1=self.alpha1, alpha2=self.alpha2,
                                       mu=self.mu, num_top_genes=self.num_top_genes)
        top_gene_pairs = get_top_gene_pairs(norm_exp_df, cell_type_array, degs_dict,
                                            num_top_pairs=self.num_top_gene_pairs,
                                            max_gene_per_ct=self.max_gene_per_ct)

        # Prepare binarized feature for all cells using the selected gene pairs
        data.data.obsm[self.out] = query_transform(all_exp_df, top_gene_pairs)
        return data
# # TODO: move to dance.transforms.tools? or make it a transform?
# def binGenes(geneStats, nbins=20, meanType="overall_mean"):
#     max = np.max(geneStats[meanType])
#     min = np.min(geneStats[meanType])
#     rrange = max - min
#     inc = rrange / nbins
#     threshs = np.arange(max, min, -1 * inc)
#     res = pd.DataFrame(index=geneStats.index.values, data=np.arange(0, geneStats.index.size, 1), columns=["bin"])
#     for i in range(0, len(threshs)):
#         res.loc[geneStats[meanType] <= threshs[i], "bin"] = len(threshs) - i
#     return res


def query_transform(exp_df: pd.DataFrame, gene_pairs: List[Tuple[str, str]]) -> pd.DataFrame:
    """Transform expression data into SCN feature given selected gene pairs.

    Parameters
    ----------
    exp_df
        Expression matrix (sample x gene).
    gene_pairs
        Selected top differentiating gene pairs (any iterable of ``(gene1, gene2)``).

    Returns
    -------
    gene_pair_diff_bin
        SCN feature. A binary (float) matrix, indexed like ``exp_df``, with one column per gene pair named
        ``"gene1&gene2"``; an entry is 1 if ``gene1`` has strictly higher expression than ``gene2`` in that sample.
        If ``gene_pairs`` is empty, a matrix with no columns is returned.

    """
    gene_pairs = list(gene_pairs)
    if not gene_pairs:
        # Nothing to compare; ``zip(*[])`` cannot be unpacked into two gene lists, so return an empty feature.
        return pd.DataFrame(index=exp_df.index, dtype=float)

    genes1 = [gene1 for gene1, _ in gene_pairs]
    genes2 = [gene2 for _, gene2 in gene_pairs]
    gene_pair_diff_bin = (exp_df[genes1].values > exp_df[genes2].values).astype(float)
    columns = ["&".join(gene_pair) for gene_pair in gene_pairs]
    return pd.DataFrame(gene_pair_diff_bin, index=exp_df.index, columns=columns)


def get_top_gene_pairs(exp_df: pd.DataFrame, cell_type_array: np.ndarray, degs_dict: Dict[str, List[str]], *,
                       num_top_pairs: int = 250, max_gene_per_ct: int = 3) -> List[Tuple[str, str]]:
    """Obtain top differentiating gene pairs.

    For each cell type, every pair of its DEGs is binarized with :func:`query_transform`, scored against the
    cell-type indicator with ``_get_deg_scores``, and the best pairs are picked by ``_get_best_gene_pairs``.

    Parameters
    ----------
    exp_df
        Expression matrix (sample x gene).
    cell_type_array
        1-d array of cell-type information for each sample.
    degs_dict
        Dictionary of differentially expressed genes for each cell type.
    num_top_pairs
        Number of top differentiating gene pairs to select per cell type.
    max_gene_per_ct
        Maximum number of a cell type's selected gene pairs that any single gene may appear in.

    Returns
    -------
    top_gene_pairs
        Sorted, deduplicated list of the gene pairs selected across all cell types.

    """
    cell_type_array = np.asarray(cell_type_array)
    top_gene_pairs = []
    for cell_type, degs in degs_dict.items():
        logger.info(f"Extracting top gene pairs for {cell_type}...")
        logger.debug(f"All DEGs:\n{degs}")
        logger.info(f"\tFirst five DEGs: {', '.join(degs[:5])}")

        gene_pairs = list(itertools.combinations(degs, 2))
        if not gene_pairs:
            # Fewer than two DEGs: no pair can be formed (scoring an empty matrix would fail).
            logger.warning(f"Cannot form gene pairs for {cell_type} (number of DEGs: {len(degs)}), skipping.")
            continue

        # Binarized comparison (gene1 > gene2) of every candidate pair, scored against the cell-type indicator
        gene_pair_diff_bin = query_transform(exp_df, gene_pairs)
        cell_type_mask = (cell_type_array == cell_type).astype(float)
        gene_pair_scores = _get_deg_scores(gene_pair_diff_bin, cell_type_mask)

        best_gene_pairs = _get_best_gene_pairs(gene_pair_scores, gene_pairs, num_pairs=num_top_pairs,
                                               max_gene_per_ct=max_gene_per_ct)
        best_gene_pairs_str = ", ".join(["&".join(gene_pair) for gene_pair in best_gene_pairs[:5]])
        logger.info(f"\tFirst five gene pairs: {best_gene_pairs_str}")

        top_gene_pairs.extend(best_gene_pairs)

    return sorted(set(top_gene_pairs))


def _get_best_gene_pairs(gene_pair_scores: np.ndarray, gene_pairs: List[Tuple[str, str]], num_pairs: int = 50,
                         max_gene_per_ct: int = 3) -> List[Tuple[str, str]]:
    """Greedily pick up to ``num_pairs`` gene pairs in decreasing score order (NaN scores ignored).

    A pair is skipped if either of its genes already appears in ``max_gene_per_ct`` selected pairs. A warning is
    logged when fewer than ``num_pairs`` pairs could be selected.
    """
    valid_idx = np.where(~np.isnan(gene_pair_scores))[0]
    sorted_idx = valid_idx[gene_pair_scores[valid_idx].argsort()[::-1]]

    best_gene_pairs = []
    count_dict = defaultdict(int)
    for idx in sorted_idx:
        # Check before selecting so that ``num_pairs <= 0`` selects nothing
        if len(best_gene_pairs) >= num_pairs:
            break
        gene_pair = gene1, gene2 = gene_pairs[idx]
        if count_dict[gene1] < max_gene_per_ct and count_dict[gene2] < max_gene_per_ct:
            best_gene_pairs.append(gene_pair)
            count_dict[gene1] += 1
            count_dict[gene2] += 1

    if len(best_gene_pairs) < num_pairs:  # did not obtain enough number of gene pairs required
        logger.warning(f"Ran out of gene pairs to select (total_pairs={sorted_idx.size:,}), target number: "
                       f"{num_pairs:,}, number of gene pairs selected: {len(best_gene_pairs):,}")

    return best_gene_pairs
def get_diff_exp_genes(exp_df: pd.DataFrame, cell_type_array: np.ndarray, *, num_top_genes: int = 100,
                       threshold: float = 0, alpha1: float = 0.05, alpha2: float = 0.001,
                       mu: float = 2) -> Dict[str, List[str]]:
    """Get differentially expressed genes via regression.

    Genes are first filtered with the statistics from ``genestats_alpha`` and ``genestats_mu``: a gene is kept if
    ``alpha > alpha1``, or if both ``alpha > alpha2`` and ``mu > mu``. The kept genes are then ranked per cell type
    by regressing their expression on the cell-type indicator (see ``_get_degs_dict``).

    Parameters
    ----------
    exp_df
        Expression matrix (sample x gene).
    cell_type_array
        1-d array of cell-type information for each sample.
    num_top_genes
        Number of top differentially expressed genes to use (taken from both ends of the ranking, so up to twice
        this many genes per cell type).
    threshold
        Gene expression threshold parameter passed to the gene statistics.
    alpha1
        Alpha 1 threshold parameter.
    alpha2
        Alpha 2 threshold parameter.
    mu
        mu threshold parameter.

    Returns
    -------
    degs_dict
        Dictionary of selected top differentially expressed genes for each cell type.

    """
    alpha_df = genestats_alpha(exp_df, threshold=threshold)
    mu_df = genestats_mu(exp_df, threshold=threshold)

    cond1 = alpha_df > alpha1
    cond2 = alpha_df > alpha2
    cond3 = mu_df > mu
    indicator = np.logical_or(cond1, np.logical_and(cond2, cond3))
    selected_genes = exp_df.columns.values[indicator]

    degs_dict = _get_degs_dict(exp_df.loc[:, selected_genes], cell_type_array, num_top_genes)

    return degs_dict


def _get_degs_dict(exp_df, cell_type_array, num_top_genes, both_ends: bool = True) -> Dict[str, List[str]]:
    """Select, for each cell type, the genes with the highest (and, if ``both_ends``, lowest) DEG scores.

    Genes with NaN scores are ignored. Selected genes are returned in ``exp_df`` column order.
    """
    # NOTE: when both_ends is set, the actual number of selected genes will be at most doubled
    degs_dict = {}
    for cell_type in np.unique(cell_type_array):
        cell_type_mask = np.zeros(cell_type_array.size)
        cell_type_mask[cell_type_array == cell_type] = 1

        cval = _get_deg_scores(exp_df, cell_type_mask)
        valid_idx = np.where(~np.isnan(cval))[0]
        # Positions into ``valid_idx`` ordered by decreasing score
        sorted_idx = cval[valid_idx].argsort()[::-1]

        # Select positively differentially expressed genes
        selected_sorted_idx = set(sorted_idx[:num_top_genes].tolist())
        if both_ends and num_top_genes > 0:
            # Add negatively differentially expressed genes. The guard matters: ``sorted_idx[-0:]`` is the whole
            # array, which would select every gene when ``num_top_genes == 0``.
            selected_sorted_idx.update(sorted_idx[-num_top_genes:].tolist())

        selected = valid_idx[sorted(selected_sorted_idx)]
        degs_dict[cell_type] = exp_df.columns[selected].tolist()

    return degs_dict


def _get_deg_scores(exp_df, cell_type_mask) -> np.ndarray:
    """Score each column of ``exp_df`` by how well the cell-type indicator explains it.

    Each column is fit by least squares on ``[cell_type_mask, 1]``; the score is ``sqrt(R^2)`` (R^2 clipped at 0)
    signed by the fitted coefficient of ``cell_type_mask``. ``exp_df`` may be a DataFrame or any 2-d array-like.
    """
    exp = np.asarray(exp_df)
    y = np.vstack([cell_type_mask, np.ones(len(cell_type_mask))]).T
    p = np.linalg.lstsq(y, exp, rcond=None)[0]
    exp_recon = y @ p
    r2 = r2_score(exp, exp_recon, multioutput="raw_values").clip(0)
    cval = np.sqrt(r2) * np.sign(p[0])
    return cval