Source code for dance.modules.single_modality.clustering.scdcc

"""Reimplementation of scDCC.

Extended from https://github.com/ttgump/scDCC

Reference
----------
Tian, Tian, et al. "Model-based deep embedding for constrained clustering analysis of single cell RNA-seq data."
Nature communications 12.1 (2021): 1-12.

"""
import math

import numpy as np
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.cluster import KMeans
from torch.nn import Parameter
from torch.utils.data import DataLoader, TensorDataset

from dance import logger
from dance.modules.base import BaseClusteringMethod, TorchNNPretrain
from dance.transforms import AnnDataTransform, Compose, SaveRaw, SetConfig
from dance.typing import Any, List, LogLevel, Optional, Tuple
from dance.utils import get_device
from dance.utils.loss import ZINBLoss


def buildNetwork(layers: List[int], network_type: str, activation: str = "relu"):
    """Stack fully connected layers into a sequential network.

    Parameters
    ----------
    layers
        Layer widths; consecutive entries give the input/output size of each linear layer.
    network_type
        Label for the kind of network being built. The body does not consult it.
    activation
        Name of the activation appended after every linear layer: ``"relu"`` or ``"sigmoid"``. Any other value
        leaves the linear layers without an activation in between.

    Returns
    -------
    An ``nn.Sequential`` holding the layers in order.

    """
    activation_table = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid}
    make_activation = activation_table.get(activation)

    modules = []
    for in_dim, out_dim in zip(layers[:-1], layers[1:]):
        modules.append(nn.Linear(in_dim, out_dim))
        if make_activation is not None:
            modules.append(make_activation())
    return nn.Sequential(*modules)


class ScDCC(nn.Module, TorchNNPretrain, BaseClusteringMethod):
    """ScDCC class.

    Parameters
    ----------
    input_dim
        Dimension of encoder input.
    z_dim
        Dimension of embedding.
    n_clusters
        Number of clusters.
    encodeLayer
        Dimensions of encoder layers.
    decodeLayer
        Dimensions of decoder layers.
    activation
        Activation function.
    sigma
        Parameter of Gaussian noise.
    alpha
        Parameter of soft assign.
    gamma
        Parameter of cluster loss.
    ml_weight
        Parameter of must-link loss.
    cl_weight
        Parameter of cannot-link loss.
    device
        Computation device.
    pretrain_path
        Stored on the instance as ``pretrain_path``. Presumably consumed by the ``TorchNNPretrain`` mixin to
        cache pretrained weights -- confirm against that class.

    """

    def __init__(
        self,
        input_dim: int,
        z_dim: int,
        n_clusters: int,
        encodeLayer: List[int],
        decodeLayer: List[int],
        activation: str = "relu",
        sigma: float = 1.,
        alpha: float = 1.,
        gamma: float = 1.,
        ml_weight: float = 1.,
        cl_weight: float = 1.,
        device: str = "auto",
        pretrain_path: Optional[str] = None,
    ):
        super().__init__()
        self.z_dim = z_dim
        self.n_clusters = n_clusters
        self.activation = activation
        self.sigma = sigma
        self.alpha = alpha
        self.gamma = gamma
        self.ml_weight = ml_weight
        self.cl_weight = cl_weight
        self.device = get_device(device)
        self.pretrain_path = pretrain_path

        self.encoder = buildNetwork([input_dim] + encodeLayer, network_type="encode", activation=activation)
        self.decoder = buildNetwork([z_dim] + decodeLayer, network_type="decode", activation=activation)
        self._enc_mu = nn.Linear(encodeLayer[-1], z_dim)
        # Three decoder heads producing the ZINB mean, dispersion and dropout probability.
        self._dec_mean = nn.Sequential(nn.Linear(decodeLayer[-1], input_dim), MeanAct())
        self._dec_disp = nn.Sequential(nn.Linear(decodeLayer[-1], input_dim), DispAct())
        self._dec_pi = nn.Sequential(nn.Linear(decodeLayer[-1], input_dim), nn.Sigmoid())

        # Cluster centres. Start from zeros rather than ``torch.Tensor(n, d)``, which returns uninitialised
        # memory (possibly NaN/inf); ``fit`` overwrites them with the KMeans centres before they are trained.
        self.mu = Parameter(torch.zeros(n_clusters, z_dim))
        self.zinb_loss = ZINBLoss().to(self.device)
        self.to(self.device)

    @staticmethod
    def preprocessing_pipeline(log_level: LogLevel = "INFO"):
        """Build the preprocessing pipeline used with this model.

        The pipeline filters genes and cells with ``min_counts=1``, applies ``SaveRaw``, then runs scanpy's
        ``normalize_total``, ``log1p`` and ``scale``, and finally sets the feature channels to
        ``X``, ``raw_X`` and ``obs["n_counts"]`` with ``"Group"`` as the label channel.

        Parameters
        ----------
        log_level
            Log level forwarded to ``Compose``.

        Returns
        -------
        The composed transform.

        """
        return Compose(
            AnnDataTransform(sc.pp.filter_genes, min_counts=1),
            AnnDataTransform(sc.pp.filter_cells, min_counts=1),
            SaveRaw(),
            AnnDataTransform(sc.pp.normalize_total),
            AnnDataTransform(sc.pp.log1p),
            AnnDataTransform(sc.pp.scale),
            SetConfig({
                "feature_channel": [None, None, "n_counts"],
                "feature_channel_type": ["X", "raw_X", "obs"],
                "label_channel": "Group"
            }),
            log_level=log_level,
        )

    def soft_assign(self, z):
        """Soft assign q with z.

        Parameters
        ----------
        z
            Embedding.

        Returns
        -------
        q
            Soft label; each row is normalised to sum to one.

        """
        # Student's t kernel between every embedding and every cluster centre.
        q = 1.0 / (1.0 + torch.sum((z.unsqueeze(1) - self.mu)**2, dim=2) / self.alpha)
        q = q**((self.alpha + 1.0) / 2.0)
        q = (q.t() / torch.sum(q, dim=1)).t()
        return q

    def target_distribution(self, q):
        """Calculate auxiliary target distribution p with q.

        Parameters
        ----------
        q
            Soft label.

        Returns
        -------
        p
            Target distribution; each row is normalised to sum to one.

        """
        # Square the assignments and divide by the per-cluster totals before renormalising each row.
        p = q**2 / q.sum(0)
        return (p.t() / p.sum(1)).t()

    def forward(self, x):
        """Forward propagation.

        Parameters
        ----------
        x
            Input features.

        Returns
        -------
        z0
            Embedding of the noise-free input.
        q
            Soft label computed from ``z0``.
        _mean
            Data mean from ZINB.
        _disp
            Data dispersion from ZINB.
        _pi
            Data dropout probability from ZINB.

        """
        # Reconstruction branch: encode a Gaussian-corrupted copy of the input, then decode it.
        h = self.encoder(x + torch.randn_like(x) * self.sigma)
        z = self._enc_mu(h)
        h = self.decoder(z)
        _mean = self._dec_mean(h)
        _disp = self._dec_disp(h)
        _pi = self._dec_pi(h)

        # Clustering branch: encode the clean input.
        h0 = self.encoder(x)
        z0 = self._enc_mu(h0)
        q = self.soft_assign(z0)
        return z0, q, _mean, _disp, _pi

    def encodeBatch(self, X, batch_size=256):
        """Batch encoder.

        Parameters
        ----------
        X
            Input features.
        batch_size
            Size of batch.

        Returns
        -------
        Embedding of every row of ``X``, detached from the autograd graph.

        """
        encoded = []
        num = X.shape[0]
        # The result is detached anyway, so skip building an autograd graph for each batch.
        with torch.no_grad():
            for start in range(0, num, batch_size):
                z, _, _, _, _ = self.forward(X[start:start + batch_size])
                encoded.append(z.detach())
        return torch.cat(encoded, dim=0)

    def cluster_loss(self, p, q):
        """Calculate cluster loss.

        Parameters
        ----------
        p
            Target distribution.
        q
            Soft label.

        Returns
        -------
        Cluster loss: ``gamma`` times the KL divergence from ``q`` to ``p``, averaged over rows.

        """

        def kld(target, pred):
            # 1e-6 keeps the ratio finite when a predicted probability is zero.
            return torch.mean(torch.sum(target * torch.log(target / (pred + 1e-6)), dim=-1))

        return self.gamma * kld(p, q)

    def pairwise_loss(self, p1, p2, cons_type):
        """Calculate pairwise loss.

        Parameters
        ----------
        p1
            Distribution 1.
        p2
            Distribution 2.
        cons_type
            Type of loss: ``"ML"`` selects the must-link loss; any other value selects the cannot-link loss.

        Returns
        -------
        Pairwise loss, scaled by ``ml_weight`` or ``cl_weight`` respectively.

        """
        if cons_type == "ML":
            ml_loss = torch.mean(-torch.log(torch.sum(p1 * p2, dim=1)))
            return self.ml_weight * ml_loss
        cl_loss = torch.mean(-torch.log(1.0 - torch.sum(p1 * p2, dim=1)))
        return self.cl_weight * cl_loss

    def pretrain(self, x, X_raw, n_counts, batch_size=256, lr=0.001, epochs=400):
        """Pretrain autoencoder.

        Parameters
        ----------
        x
            Input features.
        X_raw
            Raw input features.
        n_counts
            Total counts for each cell.
        batch_size
            Size of batch.
        lr
            Learning rate.
        epochs
            Number of epochs.

        """
        # Per-cell size factor: total counts relative to the median cell.
        size_factor = torch.tensor(n_counts / np.median(n_counts))
        dataset = TensorDataset(torch.Tensor(x), torch.Tensor(X_raw), torch.Tensor(size_factor))
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, amsgrad=True)
        for epoch in range(epochs):
            for batch_idx, (x_batch, x_raw_batch, sf_batch) in enumerate(dataloader):
                x_tensor = x_batch.to(self.device)
                x_raw_tensor = x_raw_batch.to(self.device)
                sf_tensor = sf_batch.to(self.device)
                # Only the ZINB reconstruction outputs are trained here; embedding and soft labels are ignored.
                _, _, mean_tensor, disp_tensor, pi_tensor = self.forward(x_tensor)
                loss = self.zinb_loss(x=x_raw_tensor, mean=mean_tensor, disp=disp_tensor, pi=pi_tensor,
                                      scale_factor=sf_tensor)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # NOTE(review): logged once per reporting epoch with the last batch's loss; confirm this nesting
            # against the upstream file.
            if epoch % 100 == 0:
                logger.info("Pretrain epoch [%3d], ZINB loss: %.4f", epoch + 1, loss.item())

    def fit(
        self,
        inputs: Tuple[np.ndarray, np.ndarray, np.ndarray],
        y: Optional[np.ndarray] = None,
        ml_ind1: np.ndarray = np.array([]),
        ml_ind2: np.ndarray = np.array([]),
        cl_ind1: np.ndarray = np.array([]),
        cl_ind2: np.ndarray = np.array([]),
        ml_p: float = 1.,
        cl_p: float = 1.,
        lr: float = 1.,
        batch_size: int = 256,
        epochs: int = 10,
        update_interval: int = 1,
        tol: float = 1e-3,
        pt_batch_size: int = 256,
        pt_lr: float = 0.001,
        pt_epochs: int = 400,
    ):
        """Train model.

        Parameters
        ----------
        inputs
            A tuple containing (1) the input features, (2) the raw input features, and (3) the total counts per
            cell.
        y
            True label. Used for model selection.
        ml_ind1
            Index 1 of must-link pairs.
        ml_ind2
            Index 2 of must-link pairs.
        cl_ind1
            Index 1 of cannot-link pairs.
        cl_ind2
            Index 2 of cannot-link pairs.
        ml_p
            Parameter of must-link loss.
        cl_p
            Parameter of cannot-link loss.
        lr
            Learning rate.
        batch_size
            Size of batch.
        epochs
            Number of epochs.
        update_interval
            Update interval of soft label and target distribution.
        tol
            Tolerance of the label-change stopping criterion. The criterion is currently switched off
            (``if False``), so this value has no effect.
        pt_batch_size
            Pretrain batch size.
        pt_lr
            Pretrain learning rate.
        pt_epochs
            Pretrain epochs.

        """
        X, X_raw, n_counts = inputs
        self._pretrain(X, X_raw, n_counts, batch_size=pt_batch_size, lr=pt_lr, epochs=pt_epochs, force_pretrain=True)

        X = torch.tensor(X).to(self.device)
        X_raw = torch.tensor(X_raw).to(self.device)
        sf = torch.tensor(n_counts / np.median(n_counts)).to(self.device)
        optimizer = optim.Adadelta(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, rho=.95)

        # Initializing cluster centers with kmeans
        kmeans = KMeans(self.n_clusters, n_init=20)
        data = self.encodeBatch(X)
        self.y_pred = kmeans.fit_predict(data.data.cpu().numpy())
        self.y_pred_last = self.y_pred
        self.mu.data.copy_(torch.Tensor(kmeans.cluster_centers_))

        self.train()
        num = X.shape[0]
        num_batch = int(math.ceil(1.0 * X.shape[0] / batch_size))
        ml_num_batch = int(math.ceil(1.0 * ml_ind1.shape[0] / batch_size))
        cl_num_batch = int(math.ceil(1.0 * cl_ind1.shape[0] / batch_size))
        cl_num = cl_ind1.shape[0]
        ml_num = ml_ind1.shape[0]

        # Pairwise constraints are applied every ``update_ml`` / ``update_cl`` epochs (here: every epoch).
        update_ml = 1
        update_cl = 1

        # One ARI entry, soft-label snapshot and embedding snapshot per target-distribution update.
        aris = []
        Q = {}
        Z = {}

        delta_label = np.inf
        for epoch in range(epochs):
            if epoch % update_interval == 0:
                # update the target distribution p
                latent = self.encodeBatch(X)
                q = self.soft_assign(latent)
                self.q = q
                p = self.target_distribution(q).data
                self.y_pred = self.predict()

                Q[f"epoch{epoch}"] = q
                Z[f"epoch{epoch}"] = latent.data

                # check stop criterion (deliberately disabled)
                if False:
                    delta_label = np.sum(self.y_pred != self.y_pred_last).astype(np.float32) / num
                    self.y_pred_last = self.y_pred
                    if epoch > 0 and delta_label < tol:
                        logger.info("Reach tolerance threshold (%.3e < %.3e). Stopping training.", delta_label, tol)
                        break

                # calculate ari score for model selection
                ari = self.score(None, y)
                aris.append(ari)

            # train 1 epoch for clustering loss
            train_loss = 0.0
            recon_loss_val = 0.0
            cluster_loss_val = 0.0
            for batch_idx in range(num_batch):
                xbatch = X[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
                xrawbatch = X_raw[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
                sfbatch = sf[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
                pbatch = p[batch_idx * batch_size:min((batch_idx + 1) * batch_size, num)]
                optimizer.zero_grad()
                inputs = xbatch
                rawinputs = xrawbatch
                sfinputs = sfbatch
                target = pbatch

                z, qbatch, meanbatch, dispbatch, pibatch = self.forward(inputs)

                cluster_loss = self.cluster_loss(target, qbatch)
                recon_loss = self.zinb_loss(rawinputs, meanbatch, dispbatch, pibatch, sfinputs)
                loss = cluster_loss + recon_loss
                loss.backward()
                optimizer.step()
                # Weight by batch size so the logged values are per-sample averages after dividing by ``num``.
                cluster_loss_val += cluster_loss.data * len(inputs)
                recon_loss_val += recon_loss.data * len(inputs)
                train_loss = cluster_loss_val + recon_loss_val

            if epoch % 50 == 0:
                logger.info("#Epoch %3d: Total: %.4f, Clustering Loss: %.4f, ZINB Loss: %.4f", epoch + 1,
                            train_loss / num, cluster_loss_val / num, recon_loss_val / num)

            # Must-link pass: pairwise loss plus ZINB reconstruction of both members of each pair.
            ml_loss = 0.0
            if epoch % update_ml == 0:
                for ml_batch_idx in range(ml_num_batch):
                    px1 = X[ml_ind1[ml_batch_idx * batch_size:min(ml_num, (ml_batch_idx + 1) * batch_size)]]
                    pxraw1 = X_raw[ml_ind1[ml_batch_idx * batch_size:min(ml_num, (ml_batch_idx + 1) * batch_size)]]
                    sf1 = sf[ml_ind1[ml_batch_idx * batch_size:min(ml_num, (ml_batch_idx + 1) * batch_size)]]
                    px2 = X[ml_ind2[ml_batch_idx * batch_size:min(ml_num, (ml_batch_idx + 1) * batch_size)]]
                    sf2 = sf[ml_ind2[ml_batch_idx * batch_size:min(ml_num, (ml_batch_idx + 1) * batch_size)]]
                    pxraw2 = X_raw[ml_ind2[ml_batch_idx * batch_size:min(ml_num, (ml_batch_idx + 1) * batch_size)]]
                    optimizer.zero_grad()
                    inputs1 = px1
                    rawinputs1 = pxraw1
                    sfinput1 = sf1
                    inputs2 = px2
                    rawinputs2 = pxraw2
                    sfinput2 = sf2
                    z1, q1, mean1, disp1, pi1 = self.forward(inputs1)
                    z2, q2, mean2, disp2, pi2 = self.forward(inputs2)
                    loss = (ml_p * self.pairwise_loss(q1, q2, "ML") +
                            self.zinb_loss(rawinputs1, mean1, disp1, pi1, sfinput1) +
                            self.zinb_loss(rawinputs2, mean2, disp2, pi2, sfinput2))
                    # 0.1 for mnist/reuters, 1 for fashion, the parameters are tuned via grid search on validation set
                    ml_loss += loss.data
                    loss.backward()
                    optimizer.step()

            # Cannot-link pass: pairwise loss only.
            cl_loss = 0.0
            if epoch % update_cl == 0:
                for cl_batch_idx in range(cl_num_batch):
                    px1 = X[cl_ind1[cl_batch_idx * batch_size:min(cl_num, (cl_batch_idx + 1) * batch_size)]]
                    px2 = X[cl_ind2[cl_batch_idx * batch_size:min(cl_num, (cl_batch_idx + 1) * batch_size)]]
                    optimizer.zero_grad()
                    inputs1 = px1
                    inputs2 = px2
                    z1, q1, _, _, _ = self.forward(inputs1)
                    z2, q2, _, _, _ = self.forward(inputs2)
                    loss = cl_p * self.pairwise_loss(q1, q2, "CL")
                    cl_loss += loss.data
                    loss.backward()
                    optimizer.step()

            # Both accumulators are tensors only when both constraint sets are non-empty.
            if ml_num_batch > 0 and cl_num_batch > 0:
                if epoch % 50 == 0:
                    logger.info("Pairwise Total: %.4f, ML loss: %.4f, CL loss: %.4f",
                                float(ml_loss.cpu()) + float(cl_loss.cpu()), ml_loss.cpu(), cl_loss.cpu())

        # Keep the snapshot with the best ARI; snapshots are keyed by the epoch at which they were taken.
        index = update_interval * np.argmax(aris)
        self.q = Q[f"epoch{index}"]
        self.z = Z[f"epoch{index}"]

    def predict_proba(self, x: Optional[Any] = None) -> np.ndarray:
        """Get the predicted probabilities for each cell.

        Parameters
        ----------
        x
            Not used, for compatibility with the BaseClusteringMethod class.

        Returns
        -------
        pred_prob
            Predicted probability for each cell, taken from the stored soft labels ``self.q``.

        """
        pred_prob = self.q.detach().clone().cpu().numpy()
        return pred_prob

    def predict(self, x: Optional[Any] = None) -> np.ndarray:
        """Get predictions from the trained model.

        Parameters
        ----------
        x
            Not used, for compatibility with the BaseClusteringMethod class.

        Returns
        -------
        pred
            Predicted clustering assignment for each cell.

        """
        pred = self.predict_proba().argmax(1)
        return pred

    def get_latent(self) -> torch.Tensor:
        """Get the latent representation of the input data.

        Returns
        -------
        z
            Latent representation stored by ``fit``.

        """
        return self.z
class MeanAct(nn.Module):
    """Mean activation: exponential, clamped to ``[1e-5, 1e6]``."""

    def forward(self, x):
        return x.exp().clamp(min=1e-5, max=1e6)


class DispAct(nn.Module):
    """Dispersion activation: softplus, clamped to ``[1e-4, 1e4]``."""

    def forward(self, x):
        return F.softplus(x).clamp(min=1e-4, max=1e4)