"""Reimplementation of SpaGCN.
Extended from https://github.com/jianhuupenn/SpaGCN
Reference
----------
Hu, Jian, et al. "SpaGCN: Integrating gene expression, spatial location and histology to identify spatial domains and
spatially variable genes by graph convolutional network." Nature methods 18.11 (2021): 1342-1351.
"""
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.cluster import KMeans
from torch.nn.parameter import Parameter
from dance import logger
from dance.modules.base import BaseClusteringMethod
from dance.transforms import AnnDataTransform, CellPCA, Compose, FilterGenesMatch, SetConfig
from dance.transforms.graph import SpaGCNGraph, SpaGCNGraph2D
from dance.typing import LogLevel
def calculate_p(adj, l):
    """Return the mean row sum of the Gaussian-kernel weights of ``adj``, minus one.

    Every entry ``d`` of ``adj`` is mapped to ``exp(-d**2 / (2 * l**2))``. The
    weights are summed along axis 1, the row sums are averaged, and 1 is
    subtracted from the average.

    Parameters
    ----------
    adj
        Matrix of pairwise values (presumably distances between spots with a
        zero diagonal, in which case the subtracted 1 removes each spot's own
        weight -- confirm against the caller).
    l : float
        Length scale of the kernel.

    Returns
    -------
    float
        Mean row sum of the kernel weights minus 1.

    """
    kernel = np.exp(-1 * (adj**2) / (2 * (l**2)))
    row_totals = np.sum(kernel, 1)
    mean_total = np.mean(row_totals)
    return mean_total - 1
def search_l(p, adj, start=0.01, end=1000, tol=0.01, max_run=100):
    """Bisect the interval ``[start, end]`` for an ``l`` with ``calculate_p(adj, l)`` close to ``p``.

    Parameters
    ----------
    p : float
        Target value of ``calculate_p``.
    adj
        Matrix forwarded to ``calculate_p``.
    start : float
        Lower end of the search interval.
    end : float
        Upper end of the search interval.
    tol : float
        Accepted absolute difference between the achieved value and ``p``.
    max_run : int
        Maximum number of bisection steps.

    Returns
    -------
    float or None
        The first ``l`` found whose value is within ``tol`` of ``p``, or ``None``
        when the target lies outside the interval or the step budget is spent.

    """
    lo, hi = start, end
    p_lo = calculate_p(adj, lo)
    p_hi = calculate_p(adj, hi)

    # The target is not bracketed by the interval at all.
    if p_lo > p + tol:
        logger.info("l not found, try smaller start point.")
        return None
    if p_hi < p - tol:
        logger.info("l not found, try bigger end point.")
        return None

    # One of the end points is already good enough.
    if np.abs(p_lo - p) <= tol:
        logger.info(f"recommended l: {lo}")
        return lo
    if np.abs(p_hi - p) <= tol:
        logger.info(f"recommended l: {hi}")
        return hi

    attempt = 0
    while (p_lo + tol) < p < (p_hi - tol):
        attempt += 1
        logger.info(f"Run {attempt}: l [{lo}, {hi}], p [{p_lo}, {p_hi}]")
        if attempt > max_run:
            logger.info(f"Exact l not found, closest values are:\nl={lo}: p={p_lo}\nl={hi}: p={p_hi}")
            return None

        mid = (lo + hi) / 2
        p_mid = calculate_p(adj, mid)
        if np.abs(p_mid - p) <= tol:
            logger.info(f"recommended l: {mid}")
            return mid

        # Keep the half of the interval that still brackets the target.
        if p_mid <= p:
            lo, p_lo = mid, p_mid
        else:
            hi, p_hi = mid, p_mid
    return None
def refine(sample_id, pred, dis, shape="hexagon"):
    """Smooth cluster labels using each spot's nearest neighbours.

    For every spot the ``num_nbs + 1`` entries with the smallest values in its
    row of ``dis`` are collected (normally the spot itself plus ``num_nbs``
    neighbours). The spot is relabelled to the most frequent label in that
    neighbourhood when its own label occurs fewer than ``num_nbs / 2`` times
    there and the most frequent label occurs more than ``num_nbs / 2`` times.

    Parameters
    ----------
    sample_id
        Sample ids, one per spot; used as index and columns of ``dis``.
    pred
        Initial predicted label of every spot, in the order of ``sample_id``.
    dis
        Square matrix of pairwise distances between spots.
    shape : str
        ``"hexagon"`` (6 neighbours, Visium data) or ``"square"`` (4 neighbours,
        ST data).

    Returns
    -------
    refined_pred : list
        Refined label of every spot, in the order of ``sample_id``.

    Raises
    ------
    ValueError
        If ``shape`` is neither ``"hexagon"`` nor ``"square"``.

    """
    neighbours_by_shape = {"hexagon": 6, "square": 4}
    if shape not in neighbours_by_shape:
        # Previously this case was only logged and then failed later on an unbound variable.
        raise ValueError("Shape not recognized, shape='hexagon' for Visium data, 'square' for ST data.")
    num_nbs = neighbours_by_shape[shape]

    pred = pd.DataFrame({"pred": pred}, index=sample_id)
    dis_df = pd.DataFrame(dis, index=sample_id, columns=sample_id)

    refined_pred = []
    for index in sample_id:
        # Positional slice: the closest num_nbs + 1 entries of this spot's row.
        nbs = dis_df.loc[index, :].sort_values().iloc[:num_nbs + 1]
        v_c = pred.loc[nbs.index, "pred"].value_counts()
        self_pred = pred.loc[index, "pred"]
        # With tied distances the spot itself may fall outside the slice, so its
        # label is not guaranteed to be present among the counted labels.
        own_count = v_c.get(self_pred, 0)
        if (own_count < num_nbs / 2) and (v_c.max() > num_nbs / 2):
            refined_pred.append(v_c.idxmax())
        else:
            refined_pred.append(self_pred)
    return refined_pred
class GraphConvolution(nn.Module):
    """Single graph-convolution layer, similar to https://arxiv.org/abs/1609.02907.

    The layer computes ``adj @ (input @ weight)`` and adds a bias vector when
    one is configured.

    Parameters
    ----------
    in_features : int
        Size of each input feature vector.
    out_features : int
        Size of each output feature vector.
    bias : bool
        Whether to learn an additive bias.

    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        # Always register "bias" so that ``self.bias`` exists, possibly as None.
        if not bias:
            self.register_parameter("bias", None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias uniformly from ``[-b, b]`` with ``b = 1 / sqrt(out_features)``."""
        bound = 1. / np.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Project ``input`` with the weight matrix, then aggregate with ``adj``."""
        projected = torch.mm(input, self.weight)
        aggregated = torch.spmm(adj, projected)
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

    def __repr__(self):
        return f"{type(self).__name__}({self.in_features} -> {self.out_features})"
class SimpleGCDEC(nn.Module):
    """Basic model used in SpaGCN training.

    A single :class:`GraphConvolution` layer followed by a soft assignment of
    every node to the cluster centers ``self.mu``.

    Parameters
    ----------
    nfeat : int
        Input feature dimension.
    nhid : int
        Output feature dimension.
    alpha : float, optional
        Scale parameter of the soft-assignment kernel (see :meth:`forward`).
    device : str
        Device used inside the training loops; the model is moved back to the
        CPU when training finishes.

    """

    def __init__(self, nfeat, nhid, alpha=0.2, device="cpu"):
        super().__init__()
        self.gc = GraphConvolution(nfeat, nhid)
        self.nhid = nhid
        # self.mu (the cluster centers) is created by fit / fit_with_init.
        self.alpha = alpha
        self.device = device

    def forward(self, x, adj):
        """Return the node embedding and the row-normalized soft cluster assignment."""
        x = self.gc(x, adj)
        sq_dist = torch.sum((x.unsqueeze(1) - self.mu)**2, dim=2)
        q = 1.0 / ((1.0 + sq_dist / self.alpha) + 1e-8)
        # NOTE(review): by operator precedence this is (q ** (alpha + 1)) / 2, not
        # q ** ((alpha + 1) / 2); the division by 2 cancels in the normalization
        # below. Left unchanged so that results are not altered -- confirm intent.
        q = q**(self.alpha + 1.0) / 2.0
        q = q / torch.sum(q, dim=1, keepdim=True)
        return x, q

    def loss_function(self, p, q):
        """Objective function as a Kullback-Leibler (KL) divergence loss.

        Computes ``mean_i sum_j p_ij * log(p_ij / (q_ij + 1e-6))``.
        """
        per_node = torch.sum(p * torch.log(p / (q + 1e-6)), dim=1)
        return torch.mean(per_node)

    def target_distribution(self, q):
        """Generate an auxiliary target distribution from the soft assignment ``q``.

        Parameters
        ----------
        q
            The probability of assigning cell i to cluster j.

        Returns
        -------
        p
            Target distribution: ``q**2`` divided by the per-cluster column sums of
            ``q``, then normalized so that every row sums to one.

        """
        p = q**2 / torch.sum(q, dim=0)
        p = p / torch.sum(p, dim=1, keepdim=True)
        return p

    def _make_optimizer(self, opt, lr, weight_decay):
        """Build the optimizer selected by ``opt`` over the current parameters."""
        if opt == "sgd":
            return optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        # "admin" is the historical spelling used by callers; "adam" is accepted as an alias.
        if opt in ("admin", "adam"):
            return optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
        raise ValueError(f"Unknown optimizer {opt!r}, expected 'sgd' or 'admin'.")

    @staticmethod
    def _group_means(features, labels):
        """Return the per-label mean of the rows of ``features``, ordered by sorted label."""
        return pd.DataFrame(features).groupby(np.asarray(labels)).mean().to_numpy()

    def fit(self, X, adj, lr=0.001, epochs=5000, update_interval=3, trajectory_interval=50, weight_decay=5e-4,
            opt="sgd", init="louvain", n_neighbors=10, res=0.4, n_clusters=10, init_spa=True, tol=1e-3):
        """Fit function for model training.

        Cluster centers are initialized from a k-means or Leiden clustering, then
        the model is trained by minimizing :meth:`loss_function` against a target
        distribution that is refreshed every ``update_interval`` epochs.

        Parameters
        ----------
        X
            Node features.
        adj
            Adjacency matrix.
        lr : float
            Learning rate.
        epochs : int
            Maximum number of epochs.
        update_interval : int
            Number of epochs between refreshes of the target distribution.
        trajectory_interval : int
            Number of epochs between snapshots of the hard labels stored in
            ``self.trajectory``.
        weight_decay : float
            Weight decay (only used by the ``"admin"`` optimizer).
        opt : str
            Optimizer, ``"sgd"`` or ``"admin"`` (Adam).
        init : str
            ``"louvain"`` or ``"kmeans"``.
        n_neighbors : int
            The number of neighbors used to build the graph for the ``"louvain"`` init.
        res : float
            Resolution used by the ``"louvain"`` init.
        n_clusters : int
            The number of clusters used by the ``"kmeans"`` init.
        init_spa : bool
            If true, cluster the graph-convolved features for the initialization;
            otherwise cluster ``X`` directly.
        tol : float
            Training stops once the fraction of nodes whose label changed between
            two consecutive epochs falls below this value (checked on the epoch
            following a target refresh).

        Raises
        ------
        ValueError
            If ``opt`` or ``init`` is not one of the supported values.

        """
        if init not in ("kmeans", "louvain"):
            raise ValueError(f"Unknown init {init!r}, expected 'kmeans' or 'louvain'.")

        self.trajectory = []
        # NOTE(review): the optimizer is built before self.mu is (re)assigned below, so
        # the cluster centers created by this call are not among the optimized
        # parameters. Kept to preserve existing results -- confirm intent.
        optimizer = self._make_optimizer(opt, lr, weight_decay)

        X_tensor = torch.FloatTensor(X)
        adj_tensor = torch.FloatTensor(adj)
        features = self.gc(X_tensor, adj_tensor).detach().numpy()
        # Either the graph-convolved features (expression + spatial) or the raw input.
        init_input = features if init_spa else X

        if init == "kmeans":
            logger.info("Initializing cluster centers with kmeans, n_clusters known")
            self.n_clusters = n_clusters
            y_pred = KMeans(self.n_clusters, n_init=20).fit_predict(init_input)
        else:
            logger.info(f"Initializing cluster centers with louvain, resolution = {res}")
            adata = sc.AnnData(init_input)
            sc.pp.neighbors(adata, n_neighbors=n_neighbors)
            # Leiden is used; the result is stored under the key "louvain".
            sc.tl.leiden(adata, resolution=res, key_added="louvain")
            y_pred = adata.obs["louvain"].astype(int).to_numpy()
            self.n_clusters = len(np.unique(y_pred))

        y_pred_last = y_pred
        self.trajectory.append(y_pred)
        self.mu = Parameter(torch.Tensor(self.n_clusters, self.nhid))
        self.mu.data.copy_(torch.Tensor(self._group_means(features, y_pred)))

        # Move model and data to the training device.
        device = self.device
        self.to(device)
        X_tensor = X_tensor.to(device)
        adj_tensor = adj_tensor.to(device)

        self.train()
        for epoch in range(epochs):
            if epoch % update_interval == 0:
                # The target distribution is treated as a constant, so no graph is needed.
                with torch.no_grad():
                    _, q = self(X_tensor, adj_tensor)
                    p = self.target_distribution(q)
            if epoch % 10 == 0:
                logger.info(f"Epoch {epoch}")

            optimizer.zero_grad()
            _, q = self(X_tensor, adj_tensor)
            loss = self.loss_function(p, q)
            loss.backward()
            optimizer.step()

            y_pred = torch.argmax(q, dim=1).detach().cpu().numpy()
            if epoch % trajectory_interval == 0:
                self.trajectory.append(y_pred)

            # Check stop criterion
            delta_label = np.sum(y_pred != y_pred_last).astype(np.float32) / X_tensor.shape[0]
            y_pred_last = y_pred
            if epoch > 0 and (epoch - 1) % update_interval == 0 and delta_label < tol:
                logger.info(f"delta_label {delta_label} < tol {tol}")
                logger.info("Reach tolerance threshold. Stopping training.")
                logger.info(f"Total epoch: {epoch}")
                break

        # Bring the model back to the CPU.
        self.cpu()

    def fit_with_init(self, X, adj, init_y, lr=0.001, epochs=5000, update_interval=1, weight_decay=5e-4, opt="sgd"):
        """Train the model starting from the given initial labels.

        Cluster centers are set to the per-label means of the graph-convolved
        features. Existing centers are overwritten in place; they are created
        when missing or when their shape does not match.

        Parameters
        ----------
        X
            Node features.
        adj
            Adjacency matrix.
        init_y
            Initial cluster label of every node.
        lr : float
            Learning rate.
        epochs : int
            Number of training epochs.
        update_interval : int
            Number of epochs between refreshes of the target distribution.
        weight_decay : float
            Weight decay (only used by the ``"admin"`` optimizer).
        opt : str
            Optimizer, ``"sgd"`` or ``"admin"`` (Adam).

        Raises
        ------
        ValueError
            If ``opt`` is not one of the supported values.

        """
        logger.info("Initializing cluster centers with kmeans.")
        X = torch.FloatTensor(X)
        adj = torch.FloatTensor(adj)

        # Only the embedding is needed here, which does not depend on self.mu.
        features = self.gc(X, adj).detach().numpy()
        centers = torch.as_tensor(self._group_means(features, init_y), dtype=torch.float32)
        current = getattr(self, "mu", None)
        if current is None or current.shape != centers.shape:
            self.mu = Parameter(torch.empty_like(centers))
            self.n_clusters = centers.shape[0]
        self.mu.data.copy_(centers)

        # Built after the centers exist so that they are part of the optimized parameters.
        optimizer = self._make_optimizer(opt, lr, weight_decay)

        # Move model and data to the training device.
        device = self.device
        self.to(device)
        X = X.to(device)
        adj = adj.to(device)

        self.train()
        for epoch in range(epochs):
            if epoch % update_interval == 0:
                # X and adj already are tensors on the right device; re-wrapping them
                # with torch.FloatTensor(...) would fail for non-CPU devices.
                with torch.no_grad():
                    _, q = self(X, adj)
                    p = self.target_distribution(q)
            optimizer.zero_grad()
            _, q = self(X, adj)
            loss = self.loss_function(p, q)
            loss.backward()
            optimizer.step()

        # Bring the model back to the CPU.
        self.cpu()

    def predict(self, X, adj):
        """Convert the inputs to float tensors and return ``(embedding, soft assignment)``."""
        z, q = self(torch.FloatTensor(X), torch.FloatTensor(adj))
        return z, q
# Not used in SpaGCN
class GC_DEC(nn.Module):
    """Two-layer graph-convolutional variant of the clustering model.

    Two :class:`GraphConvolution` layers with a ReLU and dropout in between,
    followed by a soft assignment of every node to the cluster centers
    ``self.mu``.

    Parameters
    ----------
    nfeat : int
        Input feature dimension.
    nhid1 : int
        Output dimension of the first layer.
    nhid2 : int
        Output dimension of the second layer (embedding dimension).
    n_clusters : int, optional
        Number of cluster centers. When ``None`` the centers are created by
        :meth:`fit` / :meth:`fit_with_init` from the initial clustering.
    dropout : float
        Dropout probability applied between the two layers.
    alpha : float
        Scale parameter of the soft-assignment kernel (see :meth:`forward`).

    """

    def __init__(self, nfeat, nhid1, nhid2, n_clusters=None, dropout=0.5, alpha=0.2):
        super().__init__()
        self.gc1 = GraphConvolution(nfeat, nhid1)
        self.gc2 = GraphConvolution(nhid1, nhid2)
        self.dropout = dropout
        if n_clusters is None:
            # The number of centers is not known yet; allocate them during fitting.
            self.register_parameter("mu", None)
        else:
            self.mu = Parameter(torch.Tensor(n_clusters, nhid2))
        self.n_clusters = n_clusters
        self.alpha = alpha

    def _embed(self, x, adj):
        """Run both graph-convolution layers and return the node embedding."""
        x = F.relu(self.gc1(x, adj))
        # Follow the module mode so that eval() disables dropout.
        x = F.dropout(x, self.dropout, training=self.training)
        return self.gc2(x, adj)

    def forward(self, x, adj):
        """Return the node embedding and the row-normalized soft cluster assignment."""
        x = self._embed(x, adj)
        sq_dist = torch.sum((x.unsqueeze(1) - self.mu)**2, dim=2)
        q = 1.0 / ((1.0 + sq_dist / self.alpha) + 1e-6)
        # NOTE(review): by operator precedence this is (q ** (alpha + 1)) / 2, not
        # q ** ((alpha + 1) / 2); the division by 2 cancels in the normalization
        # below. Left unchanged so that results are not altered -- confirm intent.
        q = q**(self.alpha + 1.0) / 2.0
        q = q / torch.sum(q, dim=1, keepdim=True)
        return x, q

    def loss_function(self, p, q):
        """Return ``mean_i sum_j p_ij * log(p_ij / (q_ij + 1e-6))`` (KL divergence loss)."""
        per_node = torch.sum(p * torch.log(p / (q + 1e-6)), dim=1)
        return torch.mean(per_node)

    def target_distribution(self, q):
        """Return ``q**2`` divided by the column sums of ``q``, normalized per row."""
        p = q**2 / torch.sum(q, dim=0)
        p = p / torch.sum(p, dim=1, keepdim=True)
        return p

    def _make_optimizer(self, opt, lr, weight_decay):
        """Build the optimizer selected by ``opt`` over the current parameters."""
        if opt == "sgd":
            return optim.SGD(self.parameters(), lr=lr, momentum=0.9)
        # "admin" is the historical spelling used by callers; "adam" is accepted as an alias.
        if opt in ("admin", "adam"):
            return optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
        raise ValueError(f"Unknown optimizer {opt!r}, expected 'sgd' or 'admin'.")

    def _set_centers(self, features, labels):
        """Set ``self.mu`` to the per-label means of ``features``, reallocating it if needed."""
        means = pd.DataFrame(features).groupby(np.asarray(labels)).mean().to_numpy()
        centers = torch.as_tensor(means, dtype=torch.float32)
        if self.mu is None or self.mu.shape != centers.shape:
            # E.g. the initial clustering found a different number of groups than n_clusters.
            self.mu = Parameter(torch.empty_like(centers))
            self.n_clusters = centers.shape[0]
        self.mu.data.copy_(centers)

    def fit(self, X, adj, lr=0.001, epochs=10, update_interval=5, weight_decay=5e-4, opt="sgd", init="louvain",
            n_neighbors=10, res=0.4):
        """Initialize the cluster centers and train the model.

        Parameters
        ----------
        X
            Node features.
        adj
            Adjacency matrix.
        lr : float
            Learning rate.
        epochs : int
            Number of training epochs.
        update_interval : int
            Number of epochs between refreshes of the target distribution.
        weight_decay : float
            Weight decay (only used by the ``"admin"`` optimizer).
        opt : str
            Optimizer, ``"sgd"`` or ``"admin"`` (Adam).
        init : str
            ``"louvain"`` or ``"kmeans"``; clustering applied to the embedding to
            obtain the initial labels.
        n_neighbors : int
            The number of neighbors used to build the graph for the ``"louvain"`` init.
        res : float
            Resolution used by the ``"louvain"`` init.

        Raises
        ------
        ValueError
            If ``opt`` or ``init`` is not supported, or if ``init="kmeans"`` is
            requested while ``n_clusters`` is unknown.

        """
        if init not in ("kmeans", "louvain"):
            raise ValueError(f"Unknown init {init!r}, expected 'kmeans' or 'louvain'.")
        if init == "kmeans" and self.n_clusters is None:
            raise ValueError("n_clusters must be set to use init='kmeans'.")

        self.trajectory = []
        logger.info("Initializing cluster centers with kmeans.")
        X = torch.FloatTensor(X)
        adj = torch.FloatTensor(adj)
        features = self._embed(X, adj).detach().numpy()

        if init == "kmeans":
            y_pred = KMeans(self.n_clusters, n_init=20).fit_predict(features)
        else:
            adata = sc.AnnData(features)
            sc.pp.neighbors(adata, n_neighbors=n_neighbors)
            sc.tl.louvain(adata, resolution=res)
            y_pred = adata.obs["louvain"].astype(int).to_numpy()

        self.trajectory.append(y_pred)
        self._set_centers(features, y_pred)
        # Built after the centers are final so that they are part of the optimized parameters.
        optimizer = self._make_optimizer(opt, lr, weight_decay)

        self.train()
        for epoch in range(epochs):
            if epoch % update_interval == 0:
                _, q = self.forward(X, adj)
                p = self.target_distribution(q).data
            if epoch % 100 == 0:
                logger.info(f"Epoch {epoch}")
            optimizer.zero_grad()
            _, q = self(X, adj)
            loss = self.loss_function(p, q)
            loss.backward()
            optimizer.step()
            self.trajectory.append(torch.argmax(q, dim=1).data.cpu().numpy())

    def fit_with_init(self, X, adj, init_y, lr=0.001, epochs=10, update_interval=1, weight_decay=5e-4, opt="sgd"):
        """Train the model starting from the given initial labels ``init_y``.

        Cluster centers are set to the per-label means of the embedding. The
        remaining parameters have the same meaning as in :meth:`fit`.

        Raises
        ------
        ValueError
            If ``opt`` is not one of the supported values.

        """
        logger.info("Initializing cluster centers with kmeans.")
        X = torch.FloatTensor(X)
        adj = torch.FloatTensor(adj)
        features = self._embed(X, adj).detach().numpy()
        self._set_centers(features, init_y)
        optimizer = self._make_optimizer(opt, lr, weight_decay)

        self.train()
        for epoch in range(epochs):
            if epoch % update_interval == 0:
                # X and adj already are float tensors; no need to re-wrap them every refresh.
                _, q = self.forward(X, adj)
                p = self.target_distribution(q).data
            optimizer.zero_grad()
            _, q = self(X, adj)
            loss = self.loss_function(p, q)
            loss.backward()
            optimizer.step()

    def predict(self, X, adj):
        """Convert the inputs to float tensors and return ``(embedding, soft assignment)``.

        Dropout stays active while the module is in training mode; call ``eval()``
        first for a deterministic result.
        """
        z, q = self(torch.FloatTensor(X), torch.FloatTensor(adj))
        return z, q
class SpaGCN(BaseClusteringMethod):
    """SpaGCN class.

    Parameters
    ----------
    l : float
        The parameter to control percentage p (length scale of the Gaussian
        kernel applied to the adjacency matrix, see :meth:`calc_adj_exp`).
    device : str
        Device passed to the underlying :class:`SimpleGCDEC` model.

    """

    def __init__(self, l=None, device="cpu"):
        self.l = l
        self.res = None
        self.device = device

    @staticmethod
    def preprocessing_pipeline(alpha: float = 1, beta: int = 49, dim: int = 50, log_level: LogLevel = "INFO"):
        """Return the preprocessing pipeline whose output channels are consumed by :meth:`fit`.

        The pipeline filters genes matching the "ERCC" and "MT-" prefixes,
        normalizes and log-transforms the counts, builds the two SpaGCN graphs,
        computes ``dim`` PCA components, and registers the feature and label
        channels.
        """
        return Compose(
            FilterGenesMatch(prefixes=["ERCC", "MT-"]),
            AnnDataTransform(sc.pp.normalize_total, target_sum=1e4),
            AnnDataTransform(sc.pp.log1p),
            SpaGCNGraph(alpha=alpha, beta=beta),
            SpaGCNGraph2D(),
            CellPCA(n_components=dim),
            SetConfig({
                "feature_channel": ["CellPCA", "SpaGCNGraph", "SpaGCNGraph2D"],
                "feature_channel_type": ["obsm", "obsp", "obsp"],
                "label_channel": "label",
                "label_channel_type": "obs"
            }),
            log_level=log_level,
        )

    def search_l(self, p, adj, start=0.01, end=1000, tol=0.01, max_run=100):
        """Search best l.

        Thin wrapper around the module-level ``search_l`` function.

        Parameters
        ----------
        p : float
            Percentage.
        adj :
            Adjacent matrix.
        start : float
            Starting value for searching l.
        end : float
            Ending value for searching l.
        tol : float
            Tolerant value for searching l.
        max_run : int
            Maximum number of runs.

        Returns
        -------
        l : float
            Best l, the parameter to control percentage p, or ``None`` when the
            search fails.

        """
        l = search_l(p, adj, start, end, tol, max_run)
        return l

    def set_l(self, l):
        """Set l.

        Parameters
        ----------
        l : float
            The parameter to control percentage p.

        """
        self.l = l

    def search_set_res(self, x, l, target_num, start=0.4, step=0.1, tol=5e-3, lr=0.05, epochs=10, max_run=10):
        """Search for a resolution that yields ``target_num`` clusters and store it in ``self.res``.

        Starting from ``start``, the resolution is moved by ``step`` towards the
        target cluster count; the step is halved whenever a move would overshoot.
        Every candidate is evaluated by fitting a fresh model.

        Parameters
        ----------
        x
            Input passed to ``fit_predict`` of the trial models.
        l : float
            The parameter to control percentage p, used by the trial models.
        target_num : int
            Desired number of clusters.
        start : float
            Initial resolution.
        step : float
            Initial step size.
        tol : float
            Stopping tolerance forwarded to the trial fits.
        lr : float
            Learning rate forwarded to the trial fits.
        epochs : int
            Number of epochs forwarded to the trial fits.
        max_run : int
            Limit on the number of search iterations.

        Returns
        -------
        res : float
            The resolution found, or the closest accepted value when the iteration
            limit is reached. The same value is stored in ``self.res``.

        """

        def count_clusters(resolution):
            # Fit a throwaway model on the same device as this instance.
            clf = SpaGCN(l, device=self.device)
            y_pred = clf.fit_predict(x, init_spa=True, init="louvain", res=resolution, tol=tol, lr=lr,
                                     epochs=epochs)
            return len(set(y_pred))

        res = start
        logger.info(f"Start at {res = :.4f}, {step = :.4f}")
        old_num = count_clusters(res)
        logger.info(f"Res = {res:.4f}, Num of clusters = {old_num}")

        run = 0
        while old_num != target_num:
            old_sign = 1 if (old_num < target_num) else -1
            candidate = res + step * old_sign
            new_num = count_clusters(candidate)
            logger.info(f"Res = {candidate:.3e}, Num of clusters = {new_num}")
            if new_num == target_num:
                res = candidate
                logger.info(f"recommended res = {res:.4f}")
                self.res = res
                return res

            new_sign = 1 if (new_num < target_num) else -1
            if new_sign == old_sign:
                # Still on the same side of the target: accept the move.
                res = candidate
                logger.info(f"Res changed to {res}")
                old_num = new_num
            else:
                # The move overshot the target: retry from the same point with a smaller step.
                step = step / 2
                logger.info(f"Step changed to {step:.4f}")

            if run > max_run:
                logger.info(f"Exact resolution not found. Recommended res = {res:.4f}")
                self.res = res
                return res
            run += 1

        logger.info(f"Recommended res = {res:.4f}")
        self.res = res
        return res

    def calc_adj_exp(self, adj: np.ndarray) -> np.ndarray:
        """Apply the Gaussian kernel ``exp(-adj**2 / (2 * l**2))`` element-wise to ``adj``."""
        adj_exp = np.exp(-1 * (adj**2) / (2 * (self.l**2)))
        return adj_exp

    def fit(self, x, y=None, *, num_pcs=50, lr=0.005, epochs=2000, weight_decay=0, opt="admin", init_spa=True,
            init="louvain", n_neighbors=10, n_clusters=None, res=0.4, tol=1e-3):
        """Fit function for model training.

        Parameters
        ----------
        x
            Pair ``(embed, adj)``: the input features and the adjacency matrix.
        y
            Not used.
        num_pcs : int
            The number of components used in PCA. Stored on the instance; not
            otherwise used by this method.
        lr : float
            Learning rate.
        epochs : int
            Maximum number of epochs.
        weight_decay : float
            Weight decay.
        opt : str
            Optimizer, ``"sgd"`` or ``"admin"`` (Adam).
        init_spa : bool
            Initialize spatial.
        init : str
            "louvain" or "kmeans".
        n_neighbors : int
            The number of neighbors used by Louvain.
        n_clusters : int
            The number of clusters used by kmeans.
        res : float
            The resolution parameter used by Louvain.
        tol : float
            Stopping tolerance forwarded to the model: the fraction of changed
            labels below which training stops.

        Raises
        ------
        ValueError
            If ``l`` has not been set.

        """
        embed, adj = x
        self.num_pcs = num_pcs
        self.res = res
        self.lr = lr
        self.epochs = epochs
        self.weight_decay = weight_decay
        self.opt = opt
        self.init_spa = init_spa
        self.init = init
        self.n_neighbors = n_neighbors
        self.n_clusters = n_clusters
        self.tol = tol
        if self.l is None:
            raise ValueError("l should be set before fitting the model!")
        # The embedding keeps the input dimensionality.
        self.model = SimpleGCDEC(embed.shape[1], embed.shape[1], device=self.device)
        adj_exp = self.calc_adj_exp(adj)
        self.model.fit(embed, adj_exp, lr=self.lr, epochs=self.epochs, weight_decay=self.weight_decay, opt=self.opt,
                       init_spa=self.init_spa, init=self.init, n_neighbors=self.n_neighbors,
                       n_clusters=self.n_clusters, res=self.res, tol=self.tol)

    def predict_proba(self, x):
        """Predict the soft cluster assignment.

        Parameters
        ----------
        x
            Pair ``(embed, adj)``: the input features and the adjacency matrix.

        Returns
        -------
        torch.Tensor
            The predicted probabilities, as returned by the model (second element
            of ``self.model.predict``).

        """
        embed, adj = x
        adj_exp = self.calc_adj_exp(adj)
        _, pred_prob = self.model.predict(embed, adj_exp)
        return pred_prob

    def predict(self, x):
        """Predict the cluster label of every sample.

        Parameters
        ----------
        x
            Pair ``(embed, adj)``: the input features and the adjacency matrix.

        Returns
        -------
        np.ndarray
            The predicted labels: the argmax over the columns of the predicted
            probabilities.

        """
        pred_prob = self.predict_proba(x)
        pred = torch.argmax(pred_prob, dim=1).data.cpu().numpy()
        return pred