Source code for dance.transforms.spatial_feature

import cv2
import numpy as np
import torch
import torchvision as tv
from sklearn.decomposition import PCA
from tqdm import tqdm, trange

from dance.transforms.base import BaseTransform
from dance.typing import Optional, Sequence
from dance.utils.matrix import normalize


class MorphologyFeature(BaseTransform):
    """Extract image-based features around each spot with a pretrained CNN.

    For every pixel coordinate read from ``channels[0]``, a square patch is cropped from the image read from
    ``channels[1]``, resized, normalized, and passed through a pretrained torchvision model whose ``fc`` head has
    been replaced by an empty ``Sequential``. The per-spot outputs are stacked, optionally reduced with PCA, and
    written to ``data.data.obsm[self.out]``.

    Parameters
    ----------
    model_name
        Name of the torchvision model constructor to use. Must be one of ``_MODELS``.
    n_components
        Number of principal components to keep. If not positive, the raw network outputs are stored.
    random_state
        Random state passed to PCA.
    crop_size
        Half-width (in pixels) of the square patch cropped around each coordinate.
    target_size
        Side length (in pixels) the cropped patch is resized to before it is fed to the network.
    device
        Torch device on which the model is run.
    channels
        Channel names for the pixel coordinates and the image, in that order.
    channel_types
        Channel types for the pixel coordinates and the image, in that order.

    Raises
    ------
    ValueError
        If ``model_name`` is not listed in ``_MODELS``.

    """

    _DISPLAY_ATTRS = ("model_name", "n_components", "crop_size", "target_size")
    # NOTE(review): models are looked up via ``getattr(tv.models, name)`` and stripped via ``.fc``. torchvision does
    # not appear to provide an ``xception`` constructor, and ``vgg16`` does not seem to expose its head as ``fc`` -
    # confirm that these two options behave as intended.
    _MODELS = ("resnet50", "inception_v3", "xception", "vgg16")

    def __init__(self, *, model_name: str = "resnet50", n_components: int = 50, random_state: int = 0,
                 crop_size: int = 20, target_size: int = 299, device: str = "cpu",
                 channels: Sequence[str] = ("spatial_pixel", "image"),
                 channel_types: Sequence[str] = ("obsm", "uns"), **kwargs):
        super().__init__(**kwargs)

        self.model_name = model_name
        self.n_components = n_components
        self.random_state = random_state
        self.crop_size = crop_size
        self.target_size = target_size
        self.device = device
        self.channels = channels
        self.channel_types = channel_types

        # Per-channel normalization constants applied to the resized patch.
        # NOTE(review): these look like a rotation of the usual ImageNet RGB statistics
        # (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) - confirm the intended channel order and that the
        # image is already scaled to the range these constants expect.
        self.mean = np.array([0.406, 0.485, 0.456])
        self.std = np.array([0.225, 0.229, 0.224])

        if self.model_name not in self._MODELS:
            raise ValueError(f"Unsupported model {self.model_name!r}, available options are: {self._MODELS}")

        self.model = getattr(tv.models, self.model_name)(pretrained=True)
        # Drop the classification head so that the forward pass yields the penultimate features.
        self.model.fc = torch.nn.Sequential()
        self.model = self.model.to(self.device)
        # The network is only used for feature extraction: switch to inference mode so that BatchNorm uses its
        # running statistics, Dropout is disabled, and inception_v3 returns a plain tensor rather than a tuple.
        self.model.eval()

    def _crop_and_process(self, image, x, y):
        """Crop a patch centered at ``(x, y)`` and convert it into a normalized model input.

        Parameters
        ----------
        image
            Image array indexed as ``image[x, y, channel]``.
        x
            Center of the patch along the first image axis.
        y
            Center of the patch along the second image axis.

        Returns
        -------
        torch.Tensor
            Float tensor of shape ``(1, channels, target_size, target_size)``.

        """
        cs = self.crop_size
        ts = self.target_size
        x, y = int(x), int(y)

        # Clamp the lower bounds: a negative slice start would wrap around to the far side of the image and
        # typically produce an empty patch for spots closer than ``crop_size`` to the border.
        x_start = max(x - cs, 0)
        y_start = max(y - cs, 0)
        img = image[x_start:x + cs, y_start:y + cs, :]

        img = cv2.resize(img, (ts, ts))
        img = (img - self.mean) / self.std
        img = img.transpose((2, 0, 1))  # channels-last -> channels-first
        img = torch.FloatTensor(img).unsqueeze(0)  # add batch dimension
        return img

    def __call__(self, data):
        """Compute the features for all spots and store them in ``data.data.obsm[self.out]``.

        Parameters
        ----------
        data
            Data object providing ``get_feature`` for the configured channels and a ``data.obsm`` mapping.

        """
        xy_pixel = data.get_feature(return_type="numpy", channel=self.channels[0], channel_type=self.channel_types[0])
        image = data.get_feature(return_type="numpy", channel=self.channels[1], channel_type=self.channel_types[1])

        # TODO: improve computational efficiency by processing images in batch.
        features = []
        # No gradients are needed for feature extraction; skipping autograd avoids building a graph per spot.
        with torch.no_grad():
            for x, y in tqdm(xy_pixel, desc="Extracting feature", bar_format="{l_bar}{bar} [ time left: {remaining} ]"):
                img = self._crop_and_process(image, x, y).to(self.device)
                feature = self.model(img).view(-1).detach().cpu().numpy()
                features.append(feature)

        morph_feat = np.array(features)
        if self.n_components > 0:
            pca = PCA(n_components=self.n_components, random_state=self.random_state)
            morph_feat = pca.fit_transform(morph_feat)

        data.data.obsm[self.out] = morph_feat


class SMEFeature(BaseTransform):
    """Blend each sample's features with a weighted average over its top-weighted graph neighbors.

    For every sample, the ``n_neighbors`` largest entries of its row in the graph read from ``channels[1]`` are
    selected, their weights are rescaled to sum to one, and the corresponding feature rows are averaged. The result
    is averaged with the sample's own features, optionally standardized and reduced with PCA, and written to
    ``data.data.obsm[self.out]``.

    Parameters
    ----------
    n_neighbors
        Number of highest-weight neighbors used for each sample. NOTE: the selection is the slice
        ``[-n_neighbors:]`` of the sorted indices, so a value of ``0`` selects every column.
    n_components
        Number of principal components to keep. If not positive, the blended features are stored without
        standardization or PCA.
    random_state
        Random state passed to PCA.
    channels
        Channel names for the feature matrix and the neighbor graph, in that order.
    channel_types
        Channel types for the feature matrix and the neighbor graph, in that order.

    """

    def __init__(self, n_neighbors: int = 3, n_components: int = 50, random_state: int = 0, *,
                 channels: Sequence[Optional[str]] = (None, "SMEGraph"),
                 channel_types: Sequence[Optional[str]] = (None, "obsp"), **kwargs):
        super().__init__(**kwargs)

        self.channels = channels
        self.channel_types = channel_types
        self.n_neighbors = n_neighbors
        self.n_components = n_components
        self.random_state = random_state

    def __call__(self, data):
        """Compute the neighbor-adjusted features and store them in ``data.data.obsm[self.out]``.

        Parameters
        ----------
        data
            Data object providing ``get_feature`` for the configured channels and a ``data.obsm`` mapping.

        """
        feat_channel, graph_channel = self.channels[0], self.channels[1]
        feat_type, graph_type = self.channel_types[0], self.channel_types[1]
        feats = data.get_feature(return_type="numpy", channel=feat_channel, channel_type=feat_type)
        graph = data.get_feature(return_type="numpy", channel=graph_channel, channel_type=graph_type)

        num_samples, _ = feats.shape
        neighbor_avg = []
        for row in trange(num_samples, desc="Adjusting data", bar_format="{l_bar}{bar} [ time left: {remaining} ]"):
            row_weights = graph[row]
            # Indices of the largest weights in this row (ascending sort, take the tail).
            top_idx = row_weights.argsort()[-self.n_neighbors:]
            top_weights = row_weights[top_idx]
            total = top_weights.sum()

            if not (total > 0):
                # No usable neighbor weight: fall back to the sample's own features.
                neighbor_avg.append(feats[row])
                continue

            scaled = top_weights / total
            neighbor_avg.append((scaled[:, None] * feats[top_idx]).sum(0))

        # Equal-weight blend of the original features and the neighbor average.
        blended = (feats + np.array(neighbor_avg)) / 2

        if self.n_components > 0:
            blended = normalize(blended, mode="standardize", axis=0)
            reducer = PCA(n_components=self.n_components, random_state=self.random_state)
            blended = reducer.fit_transform(blended)

        data.data.obsm[self.out] = blended