Source code for dance.transforms.graph.cell_feature_graph

import dgl
import numpy as np
import torch

from dance.transforms.base import BaseTransform
from dance.transforms.cell_feature import WeightedFeaturePCA
from dance.typing import LogLevel, Optional


class CellFeatureGraph(BaseTransform):
    """Build a homogeneous cell-feature graph from the nonzero entries of the feature matrix.

    Node layout: feature nodes take IDs ``0 .. num_feats - 1`` and cell nodes take IDs
    ``num_feats .. num_feats + num_cells - 1``. Every nonzero entry of the feature matrix
    produces two directed edges (cell -> feature and feature -> cell), both carrying the
    entry's value in ``edata["weight"]``. One unit-weight self-loop per node is appended
    at the end.

    The graph is written to ``data.data.uns[self.out]`` (``self.out`` is not set in this
    class; presumably it is provided by ``BaseTransform`` -- confirm).

    Parameters
    ----------
    cell_feature_channel
        Channel read with ``channel_type="obsm"`` to obtain the cell node features.
    gene_feature_channel
        Channel read with ``channel_type="varm"`` to obtain the feature (gene) node
        features. Falls back to ``cell_feature_channel`` when ``None`` (or empty).
    mod
        Modality forwarded to every ``data.get_feature`` call.
    normalize_edges
        If ``True``, rescale the weights of the incoming edges of each node so that they
        sum to the node's in-degree.
    **kwargs
        Forwarded unchanged to ``BaseTransform.__init__``.

    """

    def __init__(self, cell_feature_channel: str, gene_feature_channel: Optional[str] = None, *,
                 mod: Optional[str] = None, normalize_edges: bool = True, **kwargs):
        super().__init__(**kwargs)

        self.cell_feature_channel = cell_feature_channel
        self.gene_feature_channel = gene_feature_channel or cell_feature_channel
        self.mod = mod
        self.normalize_edges = normalize_edges

    def _alternative_construction(self, data):
        """Experimental bipartite construction; not called by ``__call__``."""
        # TODO: Try this alternative construction
        x_sparse = data.get_feature(return_type="sparse", channel=self.cell_feature_channel, mod=self.mod)
        g = dgl.bipartite_from_scipy(x_sparse, utype="cell", etype="expression", vtype="feature", eweight_name="weight")
        g = dgl.ToSimple()(g)
        g = dgl.AddSelfLoop(edge_feat_names="weight")(g)
        g = dgl.AddReverse(copy_edata=True)(g)
        # NOTE(review): the weights are registered as *edge* data (``eweight_name="weight"``)
        # but the next line reads/writes ``g.ndata["weight"]``; this looks like it should be
        # ``edata``. Left untouched because this code path is experimental -- confirm before use.
        g.ndata["weight"] = dgl.nn.EdgeWeightNorm(norm="both")(g, g.ndata["weight"])
        data.data.uns[self.out] = g
        return data

    def __call__(self, data):
        """Construct the cell-feature graph for ``data`` and store it in ``data.data.uns``.

        Returns
        -------
        The same ``data`` object, with the graph stored under ``data.data.uns[self.out]``.

        """
        feat = data.get_feature(return_type="default", mod=self.mod)
        num_cells, num_feats = feat.shape

        # Rows index cells, columns index features.
        row, col = np.nonzero(feat)
        # np.array(...).ravel() flattens whatever 1-D/2-D container the indexing returns.
        edata = np.array(feat[row, col]).ravel()[:, None]
        self.logger.info(f"Number of nonzero entries: {edata.size:,}")
        self.logger.info(f"Nonzero rate = {edata.size / num_cells / num_feats:.1%}")

        row = row + num_feats  # offset by feature nodes
        col, row = np.hstack((col, row)), np.hstack((row, col))  # convert to undirected
        edata = np.vstack((edata, edata))

        # Convert to tensors
        col = torch.LongTensor(col)
        row = torch.LongTensor(row)
        edata = torch.FloatTensor(edata)

        # Initialize cell-gene graph.
        # The node count is given explicitly: without it DGL infers the count from the
        # largest node ID that has an edge, so trailing cells without any nonzero feature
        # would be dropped and the per-node assignments below would fail on a size mismatch.
        g = dgl.graph((row, col), num_nodes=num_feats + num_cells)
        g.edata["weight"] = edata
        # FIX: change to feat_id
        # NOTE: as the marker above says, the two keys look swapped relative to their
        # content ("cell_id" enumerates the feature nodes and is -1 on cell nodes, "feat_id"
        # the other way round). They are kept as-is so existing consumers keep working.
        g.ndata["cell_id"] = torch.concat((torch.arange(num_feats, dtype=torch.int32),
                                           -torch.ones(num_cells, dtype=torch.int32)))  # yapf: disable
        g.ndata["feat_id"] = torch.concat((-torch.ones(num_feats, dtype=torch.int32),
                                           torch.arange(num_cells, dtype=torch.int32)))  # yapf: disable

        # Normalize edges and add self-loop
        if self.normalize_edges:
            # In-degrees are taken before the self-loops are added.
            in_deg = g.in_degrees()
            for i in range(g.number_of_nodes()):
                src, dst, eidx = g.in_edges(i, form="all")
                if src.shape[0] > 0:  # isolated nodes have nothing to rescale
                    edge_w = g.edata["weight"][eidx]
                    # Incoming weights of node i now sum to its in-degree.
                    g.edata["weight"][eidx] = in_deg[i] * edge_w / edge_w.sum()

        # One self-loop of weight 1 per node, appended after the normalization step.
        g.add_edges(g.nodes(), g.nodes(), {"weight": torch.ones(g.number_of_nodes())[:, None]})

        gene_feature = data.get_feature(return_type="torch", channel=self.gene_feature_channel, mod=self.mod,
                                        channel_type="varm")
        cell_feature = data.get_feature(return_type="torch", channel=self.cell_feature_channel, mod=self.mod,
                                        channel_type="obsm")
        # Stacking order matches the node layout: feature nodes first, then cell nodes.
        g.ndata["features"] = torch.vstack((gene_feature, cell_feature))

        data.data.uns[self.out] = g

        return data
class PCACellFeatureGraph(BaseTransform):
    """Chain ``WeightedFeaturePCA`` and ``CellFeatureGraph`` into a single transform.

    Parameters
    ----------
    n_components
        Passed as the first positional argument to ``WeightedFeaturePCA``.
    split_name
        Passed as the second positional argument to ``WeightedFeaturePCA``.
    normalize_edges
        Forwarded to ``CellFeatureGraph``.
    feat_norm_mode
        Forwarded to ``WeightedFeaturePCA``.
    feat_norm_axis
        Forwarded to ``WeightedFeaturePCA``.
    mod
        Forwarded to ``CellFeatureGraph`` only; it is not passed to the PCA step.
    log_level
        Logging level for this transform; the value held in ``self.log_level`` is also
        handed to both inner transforms.

    """

    _DISPLAY_ATTRS = ("n_components", "split_name")

    def __init__(
        self,
        n_components: int = 400,
        split_name: Optional[str] = None,
        *,
        normalize_edges: bool = True,
        feat_norm_mode: Optional[str] = None,
        feat_norm_axis: int = 0,
        mod: Optional[str] = None,
        log_level: LogLevel = "WARNING",
    ):
        super().__init__(log_level=log_level)

        # PCA settings
        self.n_components = n_components
        self.split_name = split_name
        self.feat_norm_mode = feat_norm_mode
        self.feat_norm_axis = feat_norm_axis

        # Graph construction settings
        self.normalize_edges = normalize_edges
        self.mod = mod

    def __call__(self, data):
        """Apply the PCA step and then the graph construction step to ``data``; return ``data``."""
        pca_step = WeightedFeaturePCA(
            self.n_components,
            self.split_name,
            feat_norm_mode=self.feat_norm_mode,
            feat_norm_axis=self.feat_norm_axis,
            log_level=self.log_level,
        )
        pca_step(data)

        # The graph step reads the channel named "WeightedFeaturePCA" -- presumably the
        # name under which the PCA step stores its result (confirm against that transform).
        graph_step = CellFeatureGraph(
            cell_feature_channel="WeightedFeaturePCA",
            mod=self.mod,
            normalize_edges=self.normalize_edges,
            log_level=self.log_level,
        )
        graph_step(data)

        return data
class CellFeatureBipartiteGraph(BaseTransform):
    """Build a bipartite cell-feature graph from a sparse feature matrix.

    The matrix is handed to ``dgl.bipartite_from_scipy`` with node types ``'cell'`` and
    ``'feature'``, edge type ``'cell2feature'`` and edge weights stored as ``'weight'``.
    Both node types receive a consecutive ``'id'`` attribute, and reverse edges are added
    with ``dgl.AddReverse(copy_edata=True, sym_new_etype=True)``.

    The graph is stored under the fixed key ``'g'``: in ``data.data.uns`` when ``mod`` is
    ``None``, otherwise in ``data.data[mod].uns``.

    Parameters
    ----------
    cell_feature_channel
        Channel passed to ``data.get_feature`` to fetch the feature matrix.
    mod
        Modality passed to ``data.get_feature``; also selects where the graph is stored.
    **kwargs
        Forwarded unchanged to ``BaseTransform.__init__``.

    """

    def __init__(self, cell_feature_channel: str, *, mod: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)

        self.mod = mod
        self.cell_feature_channel = cell_feature_channel

    def __call__(self, data):
        """Construct the bipartite graph for ``data``, store it under ``'g'`` and return ``data``."""
        matrix = data.get_feature(channel=self.cell_feature_channel, return_type="sparse", mod=self.mod)
        graph = dgl.bipartite_from_scipy(
            matrix,
            utype='cell',
            etype='cell2feature',
            vtype='feature',
            eweight_name='weight',
        )

        # Consecutive integer IDs for both node types, sized by the matrix dimensions.
        n_cells, n_features = matrix.shape
        graph.nodes['cell'].data['id'] = torch.arange(n_cells).long()
        graph.nodes['feature'].data['id'] = torch.arange(n_features).long()

        add_reverse = dgl.AddReverse(copy_edata=True, sym_new_etype=True)
        graph = add_reverse(graph)

        # Pick the container whose ``uns`` mapping receives the graph.
        target = data.data if self.mod is None else data.data[self.mod]
        target.uns['g'] = graph

        return data