Source code for dance.transforms.graph.feature_feature_graph

import dgl
import dgl.nn as dglnn
import numpy as np
import torch
from scipy.sparse import coo_matrix
from scipy.stats import spearmanr

from dance.registry import register_preprocessor
from dance.transforms.base import BaseTransform
from dance.typing import Any, Dict, Optional
from dance.utils.matrix import dist_to_rbf


@register_preprocessor("graph", "feature")
class FeatureFeatureGraph(BaseTransform):
    """Feature-feature similarity graph.

    Pairwise similarity scores between the columns (features) of the feature
    matrix are computed, thresholded, and turned into a DGL graph with one node
    per feature. The scores only decide which edges exist: edge weights are
    initialized to one and optionally normalized afterwards.

    Parameters
    ----------
    threshold
        Edge similarity score threshold. Scores whose absolute value is below
        the threshold are dropped.
    positive_only
        Only use positive similarity score if set to ``True``.
    normalize_edges
        Normalize edge weights following GCN if set to ``True``.
    score_func
        Distance function to use, supported options are ``"pearson"``, ``"spearman"``, and ``"rbf"``
    score_func_kwargs
        Optional kwargs passed to the score function, e.g. see :meth:`dance.utils.matrix.dist_to_rbf`.

    """

    _DISPLAY_ATTRS = ("threshold", "positive_only", "normalize_edges", "score_func", "score_func_kwargs")

    def __init__(self, threshold: float = 0.3, *, positive_only: bool = False, normalize_edges: bool = True,
                 score_func: str = "pearson", score_func_kwargs: Optional[Dict[str, Any]] = None, **kwargs):
        super().__init__(**kwargs)

        self.threshold = threshold
        self.positive_only = positive_only
        self.normalize_edges = normalize_edges
        self.score_func = score_func
        self.score_func_kwargs = score_func_kwargs or {}

    def __call__(self, data):
        """Build the feature-feature graph and store it in ``data.data.uns[self.out]``.

        Parameters
        ----------
        data
            Data object providing ``get_feature(return_type="numpy")``; the
            columns of the returned matrix become the graph nodes.

        Returns
        -------
        The same ``data`` object, with the graph stored in ``data.data.uns[self.out]``.

        Raises
        ------
        ValueError
            If ``score_func`` is not one of the supported options.

        """
        feat = data.get_feature(return_type="numpy")

        # Calculate correlation between features
        if self.score_func == "pearson":
            # corrcoef collapses to a 0-d value for a single feature; keep the
            # result a square matrix so the masking below always works.
            adj = np.atleast_2d(np.corrcoef(feat.T))
        elif self.score_func == "spearman":
            rho = spearmanr(feat, axis=0)[0]
            if np.ndim(rho) == 0:
                # spearmanr returns a bare scalar when there are exactly two
                # features; rebuild the 2 x 2 correlation matrix from it.
                adj = np.array([[1.0, rho], [rho, 1.0]])
            else:
                adj = rho
        elif self.score_func == "rbf":
            norm_vec = np.power(feat, 2).sum(0, keepdims=True)
            # NOTE: need to clip zero due to numerical errors caused by single
            # precision computation leading to negative values.
            dist_mat = np.sqrt((norm_vec + norm_vec.T - 2 * feat.T @ feat).clip(0))
            adj = dist_to_rbf(dist_mat, **self.score_func_kwargs)
        else:
            raise ValueError(f"Unknown similarity score function {self.score_func!r}, "
                             "supported options are: 'pearson', 'spearman', 'rbf'")

        # Cast type (astype copies, so the in-place edits below never touch
        # an array owned by the score function)
        feat = feat.astype(np.float32)
        adj = adj.astype(np.float32)

        # Constant features have undefined (NaN) correlation. NaN fails every
        # comparison below and counts as a non-zero entry in the sparse
        # matrix, which would wire such a feature to every node. Treat the
        # undefined pairwise scores as "no similarity", and the undefined
        # self-score as 1 so self-loops are handled like any other feature.
        nan_mask = np.isnan(adj)
        if nan_mask.any():
            adj[nan_mask] = 0
            nan_diag = np.flatnonzero(nan_mask.diagonal())
            adj[nan_diag, nan_diag] = 1

        # Apply threshold
        adj[np.logical_and(adj > -self.threshold, adj < self.threshold)] = 0
        if self.positive_only:
            adj[adj < 0] = 0

        # Initialize graph: every remaining non-zero score becomes an edge
        adj_coo = coo_matrix(adj)
        graph_data = (torch.from_numpy(adj_coo.row).int(), torch.from_numpy(adj_coo.col).int())
        g = dgl.graph(graph_data, num_nodes=adj.shape[0])
        g.ndata["feat"] = torch.from_numpy(feat.T)
        # Scores only determine the topology; all edges start with unit weight
        g.edata["weight"] = torch.ones(g.num_edges(), dtype=torch.float32)

        # Normalize edges
        if self.normalize_edges:
            norm = dglnn.EdgeWeightNorm()
            norm_edge_weight = norm(g, g.edata["weight"])
            g.edata["weight"] = norm_edge_weight.float()

        data.data.uns[self.out] = g

        return data