# Source code for dance.modules.single_modality.clustering.scdeepcluster

"""Reimplementation of scDeepCluster.

Extended from https://github.com/ttgump/scDeepCluster

Reference
----------
Tian, Tian, et al. "Clustering single-cell RNA-seq data with a model-based deep learning approach." Nature Machine
Intelligence 1.4 (2019): 191-198.

"""
import math

import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.cluster import KMeans
from torch.nn import Parameter
from torch.utils.data import DataLoader, TensorDataset

from dance import logger
from dance.modules.base import BaseClusteringMethod, TorchNNPretrain
from dance.transforms import AnnDataTransform, Compose, SaveRaw, SetConfig
from dance.typing import Any, List, LogLevel, Optional, Tuple
from dance.utils.loss import ZINBLoss


def buildNetwork(layers: List[int], network_type: str, activation: str = "relu"):
    """Stack fully connected layers into a sequential network.

    Every consecutive pair of entries in ``layers`` becomes one ``nn.Linear``. When ``activation`` is ``"relu"`` or
    ``"sigmoid"`` the matching activation module follows each linear layer; any other value adds no activation.

    Parameters
    ----------
    layers
        Dimensions of layers, input dimension first.
    network_type
        Type of network. Not used by this function.
    activation
        Activation function, ``"relu"`` or ``"sigmoid"``.

    Returns
    -------
    Built network as an ``nn.Sequential``. It is empty when ``layers`` has fewer than two entries.

    """
    known_activations = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid}
    act_cls = known_activations.get(activation)

    modules = []
    for in_dim, out_dim in zip(layers[:-1], layers[1:]):
        modules.append(nn.Linear(in_dim, out_dim))
        if act_cls is not None:
            modules.append(act_cls())
    return nn.Sequential(*modules)


def euclidean_dist(x, y):
    """Calculate the squared Euclidean distance between x and y.

    The element-wise squared difference is summed over dimension 1; no square root is taken.

    """
    diff = x - y
    return torch.square(diff).sum(dim=1)


class ScDeepCluster(nn.Module, TorchNNPretrain, BaseClusteringMethod):
    """ScDeepCluster class.

    A denoising autoencoder with ZINB output heads (mean, dispersion, dropout probability) whose embedding is
    clustered with a soft-assignment / target-distribution KL objective.

    Parameters
    ----------
    input_dim
        Dimension of encoder input.
    z_dim
        Dimension of embedding.
    encodeLayer
        Dimensions of encoder layers. ``None`` (default) or an empty sequence means no hidden encoder layer.
    decodeLayer
        Dimensions of decoder layers. ``None`` (default) or an empty sequence means no hidden decoder layer.
    activation
        Activation function.
    sigma
        Parameter of Gaussian noise.
    alpha
        Parameter of soft assign.
    gamma
        Parameter of cluster loss.
    device
        Computing device.
    pretrain_path
        Path to pretrained weights.

    """

    def __init__(self, input_dim, z_dim, encodeLayer=None, decodeLayer=None, activation="relu", sigma=1., alpha=1.,
                 gamma=1., device="cuda", pretrain_path: Optional[str] = None):
        super().__init__()
        # ``None`` replaces the former mutable ``[]`` defaults; copy so the caller's sequences are never shared.
        encodeLayer = [] if encodeLayer is None else list(encodeLayer)
        decodeLayer = [] if decodeLayer is None else list(decodeLayer)
        encoder_dims = [input_dim] + encodeLayer
        decoder_dims = [z_dim] + decodeLayer

        self.z_dim = z_dim
        self.activation = activation
        self.sigma = sigma
        self.alpha = alpha
        self.gamma = gamma
        self.device = device
        self.pretrain_path = pretrain_path

        self.encoder = buildNetwork(encoder_dims, network_type="encode", activation=activation)
        self.decoder = buildNetwork(decoder_dims, network_type="decode", activation=activation)
        # Using the last entry of the full dimension list keeps the heads valid when no hidden layer is given.
        self._enc_mu = nn.Linear(encoder_dims[-1], z_dim)
        self._dec_mean = nn.Sequential(nn.Linear(decoder_dims[-1], input_dim), MeanAct())
        self._dec_disp = nn.Sequential(nn.Linear(decoder_dims[-1], input_dim), DispAct())
        self._dec_pi = nn.Sequential(nn.Linear(decoder_dims[-1], input_dim), nn.Sigmoid())

        self.zinb_loss = ZINBLoss().to(self.device)
        self.to(device)

    @staticmethod
    def preprocessing_pipeline(log_level: LogLevel = "INFO"):
        """Build the preprocessing pipeline used with this model.

        Filters genes and cells with ``min_counts=1``, saves the raw data, then applies ``normalize_total``,
        ``log1p`` and ``scale``, and finally sets the feature and label channel configuration.

        Parameters
        ----------
        log_level
            Logging level passed to ``Compose``.

        Returns
        -------
        The composed transformation.

        """
        return Compose(
            AnnDataTransform(sc.pp.filter_genes, min_counts=1),
            AnnDataTransform(sc.pp.filter_cells, min_counts=1),
            SaveRaw(),
            AnnDataTransform(sc.pp.normalize_total),
            AnnDataTransform(sc.pp.log1p),
            AnnDataTransform(sc.pp.scale),
            SetConfig({
                "feature_channel": [None, None, "n_counts"],
                "feature_channel_type": ["X", "raw_X", "obs"],
                "label_channel": "Group",
            }),
            log_level=log_level,
        )

    def save_model(self, path):
        """Save model to path.

        Parameters
        ----------
        path
            Path to save model.

        """
        torch.save(self.state_dict(), path)

    def load_model(self, path):
        """Load model from path.

        Only entries whose keys exist in the current state dict are loaded; other saved entries are ignored.

        Parameters
        ----------
        path
            Path to load model.

        """
        pretrained_dict = torch.load(path, map_location=lambda storage, loc: storage)
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

    def soft_assign(self, z):
        """Soft assign q with z.

        Uses the cluster centers ``self.mu``, which are created in :meth:`fit`.

        Parameters
        ----------
        z
            Embedding.

        Returns
        -------
        q
            Soft label.

        """
        q = 1.0 / (1.0 + torch.sum((z.unsqueeze(1) - self.mu)**2, dim=2) / self.alpha)
        q = q**((self.alpha + 1.0) / 2.0)
        q = (q.t() / torch.sum(q, dim=1)).t()
        return q

    def target_distribution(self, q):
        """Calculate auxiliary target distribution p with q.

        Parameters
        ----------
        q
            Soft label.

        Returns
        -------
        p
            Target distribution.

        """
        p = q**2 / q.sum(0)
        return (p.t() / p.sum(1)).t()

    def forwardAE(self, x):
        """Forward propagation of autoencoder.

        The ZINB parameters are decoded from a noise-corrupted copy of ``x``; the returned embedding is computed
        from the clean ``x``.

        Parameters
        ----------
        x
            Input features.

        Returns
        -------
        z0
            Embedding.
        _mean
            Data mean from ZINB.
        _disp
            Data dispersion from ZINB.
        _pi
            Data dropout probability from ZINB.

        """
        h = self.encoder(x + torch.randn_like(x) * self.sigma)
        z = self._enc_mu(h)
        h = self.decoder(z)
        _mean = self._dec_mean(h)
        _disp = self._dec_disp(h)
        _pi = self._dec_pi(h)

        h0 = self.encoder(x)
        z0 = self._enc_mu(h0)
        return z0, _mean, _disp, _pi

    def forward(self, x):
        """Forward propagation.

        Parameters
        ----------
        x
            Input features.

        Returns
        -------
        z0
            Embedding.
        q
            Soft label.
        _mean
            Data mean from ZINB.
        _disp
            Data dispersion from ZINB.
        _pi
            Data dropout probability from ZINB.

        """
        # Same computation as the autoencoder pass, plus the soft assignment of the clean embedding.
        z0, _mean, _disp, _pi = self.forwardAE(x)
        q = self.soft_assign(z0)
        return z0, q, _mean, _disp, _pi

    def encodeBatch(self, x, batch_size=256):
        """Batch encoder.

        Switches the model to evaluation mode and encodes ``x`` in chunks without tracking gradients.

        Parameters
        ----------
        x
            Input features.
        batch_size
            Size of batch.

        Returns
        -------
        encoded
            Embedding.

        """
        self.eval()
        encoded = []
        # Inference only: skip autograd bookkeeping instead of detaching afterwards.
        with torch.no_grad():
            for start in range(0, x.shape[0], batch_size):
                batch = x[start:start + batch_size].to(self.device)
                z, _, _, _ = self.forwardAE(batch)
                encoded.append(z)

        encoded = torch.cat(encoded, dim=0)
        return encoded.to(self.device)

    def cluster_loss(self, p, q):
        """Calculate cluster loss.

        The loss is the batch-mean KL divergence between ``p`` and ``q``, multiplied by ``self.gamma``.

        Parameters
        ----------
        p
            Target distribution.
        q
            Soft label.

        Returns
        -------
        loss
            Cluster loss.

        """

        def kld(target, pred):
            # The small constant avoids a division by zero in the ratio.
            return torch.mean(torch.sum(target * torch.log(target / (pred + 1e-6)), dim=-1))

        kldloss = kld(p, q)
        return self.gamma * kldloss

    def pretrain(self, x, x_raw, n_counts, batch_size=256, lr=0.001, epochs=400):
        """Pretrain autoencoder.

        Optimizes the ZINB reconstruction loss with Adam (amsgrad) and logs the epoch loss every 100 epochs.

        Parameters
        ----------
        x
            Input features.
        x_raw
            Raw input features.
        n_counts
            Total counts for each cell.
        batch_size
            Size of batch.
        lr
            Learning rate.
        epochs
            Number of epochs.

        """
        self.train()
        # float32 so that the size factors match the dtype used in ``fit`` and of the other tensors.
        size_factor = torch.tensor(n_counts / np.median(n_counts), dtype=torch.float32)
        dataset = TensorDataset(torch.Tensor(x), torch.Tensor(x_raw), size_factor)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, amsgrad=True)

        for epoch in range(epochs):
            loss_val = 0
            for x_batch, x_raw_batch, sf_batch in dataloader:
                x_tensor = x_batch.to(self.device)
                x_raw_tensor = x_raw_batch.to(self.device)
                sf_tensor = sf_batch.to(self.device)
                _, mean_tensor, disp_tensor, pi_tensor = self.forwardAE(x_tensor)
                loss = self.zinb_loss(x=x_raw_tensor, mean=mean_tensor, disp=disp_tensor, pi=pi_tensor,
                                      scale_factor=sf_tensor)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_val += loss.item() * len(x_batch)

            if epoch % 100 == 0:
                logger.info("Pretrain epoch %3d, ZINB loss: %.8f", epoch + 1, loss_val / x.shape[0])

    def fit(
        self,
        inputs: Tuple[np.ndarray, np.ndarray, np.ndarray],
        y: np.ndarray,
        n_clusters: int = 10,
        init_centroid: Optional[List[int]] = None,
        y_pred_init: Optional[List[int]] = None,
        lr: float = 1,
        batch_size: int = 256,
        epochs: int = 10,
        update_interval: int = 1,
        tol: float = 1e-3,
        pt_batch_size: int = 256,
        pt_lr: float = 0.001,
        pt_epochs: int = 400,
    ):
        """Train model.

        After pretraining, the cluster centers are initialized and the model is trained with the clustering loss
        plus the ZINB reconstruction loss. The soft labels kept in ``self.q`` at the end are those of the update
        step with the best score (first one on ties).

        Parameters
        ----------
        inputs
            A tuple containing (1) the input features, (2) the raw input features, and (3) the total counts per cell.
        y
            True label. Used for model selection.
        n_clusters
            Number of clusters.
        init_centroid
            Initialization of centroids. If None, perform kmeans to initialize cluster centers.
        y_pred_init
            Predicted label for initialization.
        lr
            Learning rate.
        batch_size
            Size of batch.
        epochs
            Number of epochs.
        update_interval
            Update interval of soft label and target distribution.
        tol
            Tolerance for the label-change stopping criterion. That criterion is disabled, so this value is
            currently unused; the parameter is kept for backward compatibility.
        pt_batch_size
            Pretraining batch size.
        pt_lr
            Pretraining learning rate.
        pt_epochs
            Pretraining epochs.

        """
        x, x_raw, n_counts = inputs
        self._pretrain(x, x_raw, n_counts, batch_size=pt_batch_size, lr=pt_lr, epochs=pt_epochs, force_pretrain=True)

        self.train()
        x = torch.tensor(x, dtype=torch.float32)
        x_raw = torch.tensor(x_raw, dtype=torch.float32)
        size_factor = torch.tensor(n_counts / np.median(n_counts), dtype=torch.float32)
        # The centers must be registered before the optimizer is built so that they are optimized too.
        self.mu = Parameter(torch.empty(n_clusters, self.z_dim).to(self.device))
        optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, rho=.95)

        if init_centroid is None:
            logger.info("Initializing cluster centers with kmeans.")
            kmeans = KMeans(n_clusters, n_init=20)
            data = self.encodeBatch(x)
            self.y_pred = kmeans.fit_predict(data.data.cpu().numpy())
            self.y_pred_last = self.y_pred
            self.mu.data.copy_(torch.tensor(kmeans.cluster_centers_, dtype=torch.float32))
        else:
            self.mu.data.copy_(torch.tensor(init_centroid, dtype=torch.float32))
            self.y_pred = y_pred_init
            self.y_pred_last = self.y_pred

        num = x.shape[0]
        # Only the best-scoring soft labels are kept, instead of the labels of every update step.
        best_ari = None
        best_q = None

        for epoch in range(epochs):
            if epoch % update_interval == 0:
                # Update the target distribution p. No gradient is needed for these full-data quantities.
                latent = self.encodeBatch(x.to(self.device))
                with torch.no_grad():
                    q = self.soft_assign(latent)
                    p = self.target_distribution(q)
                self.q = q
                self.y_pred = self.predict()

                # Calculate ari score for model selection; strict ">" keeps the first best on ties.
                ari = self.score(None, y)
                if best_q is None or ari > best_ari:
                    best_ari, best_q = ari, q

            # Train 1 epoch for clustering loss.
            train_loss = 0.0
            recon_loss_val = 0.0
            cluster_loss_val = 0.0
            for start in range(0, num, batch_size):
                stop = start + batch_size
                x_batch = x[start:stop].to(self.device)
                x_raw_batch = x_raw[start:stop].to(self.device)
                sf_batch = size_factor[start:stop].to(self.device)
                target = p[start:stop].to(self.device)

                optimizer.zero_grad()
                _, q_batch, mean_batch, disp_batch, pi_batch = self.forward(x_batch)

                cluster_loss = self.cluster_loss(target, q_batch)
                recon_loss = self.zinb_loss(x_raw_batch, mean_batch, disp_batch, pi_batch, sf_batch)
                # NOTE(review): ``cluster_loss`` is already scaled by gamma, so the effective weight here is
                # gamma squared. Kept unchanged to preserve existing training behavior - confirm intent.
                loss = cluster_loss * self.gamma + recon_loss
                loss.backward()
                optimizer.step()

                n_batch = len(x_batch)
                cluster_loss_val += cluster_loss.item() * n_batch
                recon_loss_val += recon_loss.item() * n_batch
                train_loss += loss.item() * n_batch

            if epoch % 50 == 0:
                logger.info("Epoch %3d: Total: %.8f, Clustering Loss: %.8f, ZINB Loss: %.8f", epoch + 1,
                            train_loss / num, cluster_loss_val / num, recon_loss_val / num)

        if best_q is not None:
            self.q = best_q

    def predict_proba(self, x: Optional[Any] = None) -> np.ndarray:
        """Get the predicted probabilities for each cell.

        Parameters
        ----------
        x
            Not used, for compatibility with the BaseClusteringMethod class.

        Returns
        -------
        pred_prob
            Predicted probability for each cell.

        """
        pred_prob = self.q.detach().clone().cpu().numpy()
        return pred_prob

    def predict(self, x: Optional[Any] = None) -> np.ndarray:
        """Get predictions from the trained model.

        Parameters
        ----------
        x
            Not used, for compatibility with the BaseClusteringMethod class.

        Returns
        -------
        pred
            Predicted clustering assignment for each cell.

        """
        pred = self.predict_proba().argmax(1)
        return pred
class MeanAct(nn.Module):
    """Mean activation class.

    Applies an exponential and clamps the result to ``[1e-5, 1e6]``.

    """

    def forward(self, x):
        return torch.clamp(torch.exp(x), min=1e-5, max=1e6)


class DispAct(nn.Module):
    """Dispersion activation class.

    Applies a softplus and clamps the result to ``[1e-4, 1e4]``.

    """

    def forward(self, x):
        return torch.clamp(F.softplus(x), min=1e-4, max=1e4)