Source code for dance.transforms.cell_feature

import numpy as np
import pandas as pd
from sklearn.decomposition import PCA

from dance.transforms.base import BaseTransform
from dance.typing import Optional
from dance.utils.matrix import normalize


[docs]class WeightedFeaturePCA(BaseTransform): """Compute the weighted gene PCA as cell features. Given a gene expression matrix of dimension (cell x gene), the gene PCA is first compured. Then, the representation of each cell is computed by taking the weighted sum of the gene PCAs based on that cell's gene expression values. Parameters ---------- n_components Number of PCs to use. split_name Which split to use to compute the gene PCA. If not set, use all data. """ _DISPLAY_ATTRS = ("n_components", "split_name", "feat_norm_mode", "feat_norm_axis") def __init__(self, n_components: int = 400, split_name: Optional[str] = None, feat_norm_mode: Optional[str] = None, feat_norm_axis: int = 0, **kwargs): super().__init__(**kwargs) self.n_components = n_components self.split_name = split_name self.feat_norm_mode = feat_norm_mode self.feat_norm_axis = feat_norm_axis def __call__(self, data): feat = data.get_x(self.split_name) # cell x genes if self.feat_norm_mode is not None: self.logger.info(f"Normalizing feature before PCA decomposition with mode={self.feat_norm_mode} " f"and axis={self.feat_norm_axis}") feat = normalize(feat, mode=self.feat_norm_mode, axis=self.feat_norm_axis) gene_pca = PCA(n_components=self.n_components) self.logger.info(f"Start decomposing {self.split_name} features {feat.shape} (k={self.n_components})") gene_feat = gene_pca.fit_transform(feat.T) # decompose into gene features self.logger.info(f"Total explained variance: {gene_pca.explained_variance_ratio_.sum():.2%}") x = data.get_x() cell_feat = normalize(x, mode="normalize", axis=1) @ gene_feat data.data.obsm[self.out] = cell_feat.astype(np.float32) data.data.varm[self.out] = gene_feat.astype(np.float32) return data
[docs]class CellPCA(BaseTransform): _DISPLAY_ATTRS = ("n_components", ) def __init__(self, n_components: int = 400, *, channel: Optional[str] = None, mod: Optional[str] = None, **kwargs): super().__init__(**kwargs) self.n_components = n_components self.channel = channel self.mod = mod def __call__(self, data): feat = data.get_feature(return_type="numpy", channel=self.channel, mod=self.mod) pca = PCA(n_components=self.n_components) self.logger.info(f"Start generating cell PCA features {feat.shape} (k={self.n_components})") cell_feat = pca.fit_transform(feat) evr = pca.explained_variance_ratio_ self.logger.info(f"Top 10 explained variances: {evr[:10]}") self.logger.info(f"Total explained variance: {evr.sum():.2%}") data.data.obsm[self.out] = cell_feat return data
[docs]class BatchFeature(BaseTransform): """Assign statistical batch features for each cell.""" def __init__(self, *, channel: Optional[str] = None, mod: Optional[str] = None, **kwargs): super().__init__(**kwargs) self.channel = channel self.mod = mod def __call__(self, data): # TODO: use get_feature; move mod1 to mod; replace for loop with np cells = [] columns = [ "cell_mean", "cell_std", "nonzero_25%", "nonzero_50%", "nonzero_75%", "nonzero_max", "nonzero_count", "nonzero_mean", "nonzero_std", "batch", ] # yapf: disable ad_input = data.data["mod1"] bcl = list(ad_input.obs["batch"]) print(set(bcl)) for i, cell in enumerate(ad_input.X): cell = cell.toarray() nz = cell[np.nonzero(cell)] if len(nz) == 0: raise ValueError("Error: one cell contains all zero features.") cells.append([ cell.mean(), cell.std(), np.percentile(nz, 25), np.percentile(nz, 50), np.percentile(nz, 75), cell.max(), len(nz) / 1000, nz.mean(), nz.std(), bcl[i] ]) cell_features = pd.DataFrame(cells, columns=columns) batch_source = cell_features.groupby("batch").mean().reset_index() batch_list = batch_source.batch.tolist() batch_source = batch_source.drop("batch", axis=1).to_numpy().tolist() b2i = dict(zip(batch_list, range(len(batch_list)))) batch_features = [] for b in ad_input.obs["batch"]: batch_features.append(batch_source[b2i[b]]) batch_features = np.array(batch_features).astype(float) data.data["mod1"].obsm["batch_features"] = batch_features return data