Source code for dance.transforms.graph.scmogcn_graph

import os
import pickle
from collections import defaultdict

import dgl
import numpy as np
import scipy.sparse
import torch
from sklearn.decomposition import TruncatedSVD

from dance import logger
from dance.data.base import Data
from dance.registry import register_preprocessor
from dance.transforms.base import BaseTransform
from dance.typing import Union


def read_gmt(entrez_string, symbol_string):
    gene_list = entrez_string.split()

    gene_sets_entrez = defaultdict(list)
    indicator = 0
    for ele in gene_list:
        if ele.isnumeric():
            gene_sets_entrez[gene_set_name].append(ele)
        elif indicator == 1:
            indicator = 0
        elif indicator == 0:
            indicator = 1
            gene_set_name = ele

    gene_list = symbol_string.split()
    gene_sets_symbols = defaultdict(list)

    for ele in gene_list:
        if ele in gene_sets_entrez:
            gene_set_name = ele
        elif not ele.startswith('http://'):
            gene_sets_symbols[gene_set_name].append(ele)
    return gene_sets_symbols


def create_pathway_graph(gex_features: scipy.sparse.spmatrix, gene_names: Union[str], pathway_weight: str,
                         pathway_threshold: float, subtask: str, pathway_path: str):
    """Generate nodes, edges and edge weights for pathway dataset.

    Parameters
    ----------
    gex_features: scipy.sparse.spmatrix
        Gene expression data.
    gene_names: Union[str]
        Gene names.
    pathway_weight: str
        Weighting scheme for pathway filtering.
    pathway_threshold: float
        Threshold for pathway filtering.
    subtask: str
        Subtask name.
    pathway_path: str
        Path to pathway file.

    Returns
    --------
    uu: List[int]
        Predecessor node id of each edge.
    vv: List[int]
        Successor node id of each edge.
    ee: List[float]
        Edge weight of each edge.

    """

    uu = []
    vv = []
    ee = []

    pk_path = f'pw_{subtask}_{pathway_weight}.pkl'
    if os.path.exists(pk_path):
        logger.warning("Pathway file exist. Load pickle file by default. "
                       "Auguments '--pathway_weight' and '--pathway_path' will not take effect.")
        uu, vv, ee = pickle.load(open(pk_path, 'rb'))
    else:
        # Load Original Pathway File
        with open(pathway_path + '.entrez.gmt') as gmt:
            entrez_string = gmt.read()
        with open(pathway_path + '.symbols.gmt') as gmt:
            symbols_string = gmt.read()
        gene_sets_symbols = read_gmt(entrez_string, symbols_string)

        pw = [i[1] for i in gene_sets_symbols.items()]
        # Generate New Pathway Data
        new_pw = []
        for i in pw:
            new_pw.append([])
            for j in i:
                if j in gene_names:
                    new_pw[-1].append(gene_names.index(j))

        if pathway_weight == 'cos':
            for i in new_pw:
                for j in i:
                    for k in i:
                        if j != k:
                            uu.append(j)
                            vv.append(k)
                            sj = np.sqrt(np.dot(gex_features[:, j].T, gex_features[:, j]).item())
                            sk = np.sqrt(np.dot(gex_features[:, k].T, gex_features[:, k]).item())
                            jk = np.dot(gex_features[:, j].T, gex_features[:, k])
                            cossim = jk / sj / sk
                            ee.append(cossim.item())
        elif pathway_weight == 'one':
            for i in new_pw:
                for j in i:
                    for k in i:
                        if j != k:
                            uu.append(j)
                            vv.append(k)
                            ee.append(1.)
        elif pathway_weight == 'pearson':
            corr = np.corrcoef(gex_features.T.todense())
            for i in new_pw:
                for j in i:
                    for k in i:
                        if j != k:
                            uu.append(j)
                            vv.append(k)
                            ee.append(1 - corr[j][k])
        elif pathway_weight == 'spearman':
            spe = scipy.stats.spearmanr(gex_features.toarray())[0]
            if gex_features.toarray().shape[0] == 2:
                spe = np.array([[1, spe], [spe, 1]])
            for i in new_pw:
                for j in i:
                    for k in i:
                        if j != k:
                            uu.append(j)
                            vv.append(k)
                            ee.append(1 - spe[j][k])

        pickle.dump([uu, vv, ee], open(pk_path, 'wb'))

    # Apply Threshold
    nu = []
    nv = []
    ne = []

    for i in range(len(uu)):
        if abs(ee[i]) > pathway_threshold:
            ne.append(ee[i])
            nu.append(uu[i])
            nv.append(vv[i])
    uu, vv, ee = nu, nv, ne

    return uu, vv, ee


def construct_enhanced_feature_graph(u, v, e, train_size, feature_size, cell_node_features, inductive=False,
                                     enhance_graph=None, _test_graph=False):
    """Generate a feature-cell graph, enhanced with domain-knowledge (e.g. pathway).

    Parameters
    ----------
    u: torch.Tensor
        1-dimensional tensor. Cell node id of each cell-feature edge.
    v: torch.Tensor
        1-dimensional tensor. Feature node id of each cell-feature edge.
    e: torch.Tensor
        1-dimensional tensor. Weight of each cell-feature edge.
    train_size: int
        Number of training cells.
    feature_size: int
        Number of features.
    cell_node_features: torch.Tensor
        Cell node initial features.
    inductive: bool
        Whether to use inductive learning.
    enhance_graph: dgl.DGLGraph
        Enhanced graph.
    _test_graph: bool
        Whether to use test graph.


    Returns
    --------
    graph: DGLGraph
        The generated graph.

    """

    if enhance_graph is None:

        graph_data = {
            ('cell', 'cell2feature', 'feature'): (u, v),
            ('feature', 'feature2cell', 'cell'): (v, u),
        }

        graph = dgl.heterograph(graph_data)

        if inductive:
            graph.nodes['cell'].data['id'] = cell_node_features[:train_size] if not _test_graph else cell_node_features
        else:
            graph.nodes['cell'].data['id'] = cell_node_features
        feature_size = min(graph.num_nodes('feature'), feature_size)
        graph.nodes['feature'].data['id'] = torch.arange(feature_size).long()
        graph.edges['feature2cell'].data['weight'] = e
        graph.edges['cell2feature'].data['weight'] = e[:graph.edges(etype='cell2feature')[0].shape[0]]

    else:
        uu, vv, ee = enhance_graph

        graph_data = {
            ('cell', 'cell2feature', 'feature'): (u, v),
            ('feature', 'feature2cell', 'cell'): (v, u),
            ('feature', 'pathway', 'feature'): (uu, vv),
        }
        graph = dgl.heterograph(graph_data)

        if inductive:
            graph.nodes['cell'].data['id'] = cell_node_features[:train_size] if not _test_graph else cell_node_features
        else:
            graph.nodes['cell'].data['id'] = cell_node_features
        graph.nodes['feature'].data['id'] = torch.arange(feature_size).long()
        graph.edges['feature2cell'].data['weight'] = e.float()
        graph.edges['cell2feature'].data['weight'] = e[:graph.edges(etype='cell2feature')[0].shape[0]].float()
        graph.edges['pathway'].data['weight'] = torch.tensor(ee).float()

    return graph


[docs]@register_preprocessor("graph", "cell") class ScMoGNNGraph(BaseTransform): """Construct the cell-feature graph object for scmognn. Parameters ---------- inductive: bool Whether to use inductive learning. Default: False. cell_init: str Initialization method for cell features. Default: 'none'. pathway: bool Whether to use pathway information. Default: True. subtask: str Subtask name. Default: 'gex2adt'. pathway_weight: str Weighting scheme for pathway filtering. Default: None. pathway_threshold: float Threshold for pathway filtering. Default: 0. pathway_path: str Path to pathway file. Default: 'data/h.all.v7.4'. Returns -------- g: DGLGraph The generated graph. """ def __init__(self, inductive: bool = False, cell_init: str = 'none', pathway=True, subtask='openproblems_bmmc_cite_phase2_rna', pathway_weight=None, pathway_threshold: float = 0., pathway_path: str = 'data/h.all.v7.4', **kwargs): super().__init__(**kwargs) self.inductive = inductive self.cell_init = cell_init self.pathway = pathway self.subtask = subtask self.pathway_weight = pathway_weight self.pathway_threshold = pathway_threshold self.pathway_path = pathway_path def __call__(self, data: Data) -> Data: x_train, _ = data.get_train_data(return_type="numpy") x_train_sparse, _ = data.get_train_data(return_type="sparse") x_test_sparse, _ = data.get_test_data(return_type="sparse") cell_size = x_train_sparse.shape[0] + x_test_sparse.shape[0] if not self.inductive else x_train_sparse.shape[0] train_size = x_train_sparse.shape[0] feature_size = x_train_sparse.shape[1] if self.cell_init == 'none': cell_node_features = torch.ones(cell_size).long() elif self.cell_init == 'svd': embedder_mod1 = TruncatedSVD(n_components=100) X_train_np = embedder_mod1.fit_transform(x_train_sparse) X_test_np = embedder_mod1.transform(x_test_sparse) cell_node_features = torch.cat([torch.from_numpy(X_train_np), torch.from_numpy(X_test_np)], 0).float() if self.pathway: gene_names = data.data['mod1'].var_names.tolist() enhance_graph = create_pathway_graph(gex_features=x_train, gene_names=gene_names, pathway_weight=self.pathway_weight, pathway_threshold=self.pathway_threshold, subtask=self.subtask, pathway_path=self.pathway_path) else: enhance_graph = None if self.inductive: u = torch.from_numpy( np.concatenate([np.array(t.nonzero()[0] + i) for i, t in enumerate(x_train_sparse)], axis=0)) v = torch.from_numpy(np.concatenate([np.array(t.nonzero()[1]) for t in x_train_sparse], axis=0)) e = torch.from_numpy(x_train_sparse.tocsr().data).float() g = construct_enhanced_feature_graph(u, v, e, train_size, feature_size, cell_node_features, self.inductive, enhance_graph, _test_graph=False) u = torch.from_numpy(np.concatenate([np.array(t.nonzero()[0] + i) for i, t in enumerate(x_train_sparse)] + \ [np.array(t.nonzero()[0] + i + train_size) for i, t in enumerate(x_test_sparse)], axis=0)) v = torch.from_numpy(np.concatenate([np.array(t.nonzero()[1]) for t in x_train_sparse] + \ [np.array(t.nonzero()[1]) for t in x_test_sparse], axis=0)) e = torch.from_numpy(np.concatenate( [x_train_sparse.tocsr().data, x_test_sparse.tocsr().data], axis=0)).float() gtest = construct_enhanced_feature_graph(u, v, e, train_size, feature_size, cell_node_features, self.inductive, enhance_graph, _test_graph=True) data.data.uns['g'] = g data.data.uns['gtest'] = gtest else: u = torch.from_numpy(np.concatenate([np.array(t.nonzero()[0] + i) for i, t in enumerate(x_train_sparse)] + \ [np.array(t.nonzero()[0] + i + train_size) for i, t in enumerate(x_test_sparse)], axis=0)) v = torch.from_numpy(np.concatenate([np.array(t.nonzero()[1]) for t in x_train_sparse] + \ [np.array(t.nonzero()[1]) for t in x_test_sparse], axis=0)) e = torch.from_numpy(np.concatenate( [x_train_sparse.tocsr().data, x_test_sparse.tocsr().data], axis=0)).float() g = construct_enhanced_feature_graph(u, v, e, train_size, feature_size, cell_node_features, self.inductive, enhance_graph) data.data.uns['g'] = g return data