Multi-modality tasks

Joint embedding

class dance.modules.multi_modality.joint_embedding.DCCA(layer_e_1, hidden1_1, Zdim_1, layer_d_1, hidden2_1, layer_e_2, hidden1_2, Zdim_2, layer_d_2, hidden2_2, args, ground_truth1, Type_1='NB', Type_2='Bernoulli', cycle=1, attention_loss='Eucli', droprate=0.1)

DCCA class.

Parameters:
  • layer_e_1 (list[int]) – Hidden layer specification for encoder1. List the dimensions of each hidden layer sequentially.

  • hidden1_1 (int) – Hidden dimension for encoder1. It should be consistent with the last layer in layer_e_1.

  • Zdim_1 (int) – Latent space dimension for VAE1.

  • layer_d_1 (list[int]) – Hidden layer specification for decoder1. List the dimensions of each hidden layer sequentially.

  • hidden2_1 (int) – Hidden dimension for decoder1. It should be consistent with the last layer in layer_d_1.

  • layer_e_2 (list[int]) – Hidden layer specification for encoder2. List the dimensions of each hidden layer sequentially.

  • hidden1_2 (int) – Hidden dimension for encoder2. It should be consistent with the last layer in layer_e_2.

  • Zdim_2 (int) – Latent space dimension for VAE2.

  • layer_d_2 (list[int]) – Hidden layer specification for decoder2. List the dimensions of each hidden layer sequentially.

  • hidden2_2 (int) – Hidden dimension for decoder2. It should be consistent with the last layer in layer_d_2.

  • args (argparse.Namespace) – A Namespace object that contains arguments of DCCA. For details of the parameters, refer to the parser help document.

  • ground_truth1 (torch.Tensor) – Extra labels for VAE1.

  • Type_1 (str, optional) – Loss type for VAE1. Defaults to ‘NB’.

  • Type_2 (str, optional) – Loss type for VAE2. Defaults to ‘Bernoulli’.

  • cycle (int, optional) – Number of training cycles. In each cycle, VAE1 and VAE2 are updated iteratively. Defaults to 1.

  • attention_loss (str, optional) – Type of the attention loss. Defaults to ‘Eucli’.

  • droprate (float, optional) – Dropout rate for encoder/decoder layers. Defaults to 0.1.
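Each hidden1_* and hidden2_* value must match the last entry of its layer list, so a mismatch can be caught before the model is built. The following is a minimal sketch in plain Python; check_dcca_dims is a hypothetical helper, not part of DANCE, and the layer sizes are illustrative only.

```python
from typing import List


def check_dcca_dims(layer_e: List[int], hidden1: int,
                    layer_d: List[int], hidden2: int) -> None:
    """Hypothetical helper (not part of DANCE).

    Raises ValueError unless hidden1 equals the last encoder layer and
    hidden2 equals the last decoder layer, as the DCCA parameters require.
    """
    if not layer_e or not layer_d:
        raise ValueError("layer specifications must be non-empty lists")
    if hidden1 != layer_e[-1]:
        raise ValueError(f"hidden1={hidden1} does not match layer_e[-1]={layer_e[-1]}")
    if hidden2 != layer_d[-1]:
        raise ValueError(f"hidden2={hidden2} does not match layer_d[-1]={layer_d[-1]}")


# Illustrative sizes only, not recommended settings.
layer_e_1, layer_d_1 = [256, 128], [128, 256]
layer_e_2, layer_d_2 = [128, 64], [64, 128]
check_dcca_dims(layer_e_1, layer_e_1[-1], layer_d_1, layer_d_1[-1])  # VAE1
check_dcca_dims(layer_e_2, layer_e_2[-1], layer_d_2, layer_d_2[-1])  # VAE2
```

Deriving hidden1_* and hidden2_* from the lists themselves, as in the last two lines, avoids the mismatch altogether.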

fit(train_loader, test_loader, total_loader, first='RNA')

Fit function for training.

Parameters:
  • train_loader (torch.utils.data.DataLoader) – Dataloader for training dataset.

  • test_loader (torch.utils.data.DataLoader) – Dataloader for testing dataset.

  • total_loader (torch.utils.data.DataLoader) – Dataloader for both training and testing dataset, for extra evaluation purposes.

  • first (str, optional) – Type of modality 1. Defaults to ‘RNA’.

Returns:

None.

forward(total_loader)

Forward function for torch.nn.Module. An alias of encode_Batch function.

Parameters:

total_loader (torch.utils.data.DataLoader) – Dataloader for dataset.

Returns:

  • latent_z1 (numpy.ndarray) – Latent representation of modality 1.

  • latent_z2 (numpy.ndarray) – Latent representation of modality 2.

  • norm_x1 (numpy.ndarray) – Normalized representation of modality 1.

  • recon_x1 (numpy.ndarray) – Reconstruction result of modality 1.

  • norm_x2 (numpy.ndarray) – Normalized representation of modality 2.

  • recon_x2 (numpy.ndarray) – Reconstruction result of modality 2.

predict(total_loader)

Predict function to get latent representation of data.

Parameters:

total_loader (torch.utils.data.DataLoader) – Dataloader for dataset.

Returns:

  • emb1 (numpy.ndarray) – Latent representation of modality 1.

  • emb2 (numpy.ndarray) – Latent representation of modality 2.

score(dataloader, metric='clustering')

Score function to get score of prediction.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – Dataloader for testing dataset.

  • metric (str, optional) – Evaluation metric. Defaults to ‘clustering’.

Returns:

  • NMI_score1 (float) – NMI score for VAE1.

  • ARI_score1 (float) – ARI score for VAE1.

  • NMI_score2 (float) – NMI score for VAE2.

  • ARI_score2 (float) – ARI score for VAE2.
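NMI and ARI measure how well a clustering of the latent representation agrees with known labels. As a rough, dependency-free sketch of what these numbers capture (not necessarily how DANCE computes them), both can be derived from the contingency table of true and predicted labels. The function names are hypothetical, and NMI here uses the arithmetic-mean normalization.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def _pairs(n: int) -> float:
    """Number of unordered pairs among n items, i.e. C(n, 2)."""
    return n * (n - 1) / 2


def adjusted_rand_index(true: Sequence[Hashable], pred: Sequence[Hashable]) -> float:
    """ARI from the contingency table (hypothetical helper)."""
    n = len(true)
    if n < 2:
        return 1.0
    index = sum(_pairs(c) for c in Counter(zip(true, pred)).values())
    sum_a = sum(_pairs(c) for c in Counter(true).values())  # row sums
    sum_b = sum(_pairs(c) for c in Counter(pred).values())  # column sums
    expected = sum_a * sum_b / _pairs(n)
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:  # degenerate case, e.g. one cluster in both labelings
        return 1.0
    return (index - expected) / (max_index - expected)


def normalized_mutual_info(true: Sequence[Hashable], pred: Sequence[Hashable]) -> float:
    """NMI with arithmetic-mean normalization (hypothetical helper)."""
    n = len(true)
    a, b = Counter(true), Counter(pred)
    mi = sum(c / n * math.log(c * n / (a[u] * b[v]))
             for (u, v), c in Counter(zip(true, pred)).items())
    h_true = -sum(c / n * math.log(c / n) for c in a.values())
    h_pred = -sum(c / n * math.log(c / n) for c in b.values())
    denom = (h_true + h_pred) / 2
    return 1.0 if denom == 0 else mi / denom
```

Both scores ignore how clusters are named: relabeling [0, 0, 1, 1] as [1, 1, 0, 0] still yields a perfect score of 1.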

class dance.modules.multi_modality.joint_embedding.JAE(nb_cell_types, nb_batches, nb_phases, input_dimension)

forward(x)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

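class dance.modules.multi_modality.joint_embedding.ScMoGCNWrapper(args, num_celL_types, num_batches, num_phases, num_features)

ScMoGCN class.

Parameters:
  • args (argparse.Namespace) – A Namespace object that contains arguments of ScMoGCN. For details of the parameters, refer to the parser help document.

  • num_celL_types (int) – Number of cell types.

  • num_batches (int) – Number of batches.

  • num_phases (int) – Number of cell cycle phase scores.

  • num_features (int) – Number of input features.

fit(g_mod1, g_mod2, train_size, cell_type, batch_label, phase_score)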

Fit function for training.

Parameters:
  • g_mod1 (dgl.DGLGraph) – Bipartite expression feature graph for modality 1.

  • g_mod2 (dgl.DGLGraph) – Bipartite expression feature graph for modality 2.

  • train_size (int) – Number of training samples.

  • cell_type (torch.Tensor) – Cell type labels for training samples.

  • batch_label (torch.Tensor) – Batch labels for training samples.

  • phase_score (torch.Tensor) – Cell cycle phase scores for training samples.

Returns:

None.

load(path, map_location=None)

Load model parameters from checkpoint file.

Parameters:
  • path (str) – Path to the checkpoint file.

  • map_location (str, optional) – Device to map the loaded parameters to. If not None, it is passed to torch.load as the map_location argument.

Returns:

None.

predict(idx)

Predict function to get latent representation of data.

Parameters:

idx (Iterable[int]) – Index of testing samples for prediction.

Returns:

prediction – Joint embedding of input data.

Return type:

torch.Tensor

score(idx, cell_type, phase_score=None, adata_sol=None, metric='loss')

Score function to get score of prediction.

Parameters:
  • idx (Iterable[int]) – Index of testing samples for scoring.

  • cell_type (torch.Tensor) – Cell type labels of testing samples.

  • phase_score (torch.Tensor, optional) – Cell cycle score of testing samples.

  • metric (str, optional) – Type of evaluation metric. Defaults to ‘loss’.

Returns:

  • loss1 (float) – Reconstruction loss.

  • loss2 (float) – Cell type classification loss.

  • loss3 (float) – Batch regularization loss.

  • loss4 (float) – Cell cycle score loss.

to(device)

Performs device conversion.

Parameters:

device (str) – Target device.

Returns:

self – Converted model.

Return type:

ScMoGCNWrapper

class dance.modules.multi_modality.joint_embedding.scMVAE(encoder_1, hidden_1, Z_DIMS, decoder_share, share_hidden, decoder_1, hidden_2, encoder_l, hidden3, encoder_2, hidden_4, encoder_l1, hidden3_1, decoder_2, hidden_5, drop_rate, log_variational=True, Type='Bernoulli', device='cpu', n_centroids=19, penality='GMM', model=2)

fit(args, train, valid, final_rate, scale_factor, device)

Fit function for training.

Parameters:
  • args (argparse.Namespace) – A Namespace object that contains training arguments.

  • train (torch.utils.data.DataLoader) – Dataloader for training dataset.

  • valid (torch.utils.data.DataLoader) – Dataloader for validation dataset.

  • final_rate – Final-rate hyperparameter used during training.

  • scale_factor – Scale factor hyperparameter used during training.

  • device (torch.device) – Device used for training.

Returns:

None.

forward(X1, X2, local_l_mean, local_l_var, local_l_mean1, local_l_var1)

Forward function for torch.nn.Module. An alias of encode_Batch function.

Parameters:
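  • X1 (torch.Tensor) – Features of modality 1.

  • X2 (torch.Tensor) – Features of modality 2.

  • local_l_mean (torch.Tensor) – Mean of the log library size of modality 1.

  • local_l_var (torch.Tensor) – Variance of the log library size of modality 1.

  • local_l_mean1 (torch.Tensor) – Mean of the log library size of modality 2.

  • local_l_var1 (torch.Tensor) – Variance of the log library size of modality 2.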

Returns:

  • latent_z1 (numpy.ndarray) – Latent representation of modality 1.

  • latent_z2 (numpy.ndarray) – Latent representation of modality 2.

  • norm_x1 (numpy.ndarray) – Normalized representation of modality 1.

  • recon_x1 (numpy.ndarray) – Reconstruction result of modality 1.

  • norm_x2 (numpy.ndarray) – Normalized representation of modality 2.

  • recon_x2 (numpy.ndarray) – Reconstruction result of modality 2.
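The library-size arguments of forward are assumed here to follow the scVI convention: local_l_mean and local_l_var are the mean and variance of the log library size (total counts per cell), and local_l_mean1/local_l_var1 are the same statistics for modality 2. Under that assumption, they could be computed from a raw count matrix as in this sketch. log_library_stats is a hypothetical helper, and cells with zero total counts must be filtered out first.

```python
import math
from statistics import mean, pvariance
from typing import List, Sequence, Tuple


def log_library_stats(counts: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Mean and (population) variance of log library sizes.

    Hypothetical helper (assumes the scVI convention). counts is a
    cells x genes matrix of raw counts; each cell's library size is its
    row sum, which must be positive.
    """
    log_lib: List[float] = [math.log(sum(row)) for row in counts]
    return mean(log_lib), pvariance(log_lib)


# Toy modality-1 counts for two cells (illustrative values only).
local_l_mean, local_l_var = log_library_stats([[1, 1], [2, 2]])
```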

init_gmm_params(Dataloader)

Initialize the parameters of the PoE model.

Parameters:

Dataloader (torch.utils.data.DataLoader) – Dataloader for the whole dataset.

Returns:

None.

predict(X1, X2, out='Z', device='cpu')

Predict function to get prediction.

Parameters:
  • X1 (torch.Tensor) – Features of modality 1.

  • X2 (torch.Tensor) – Features of modality 2.

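  • out (str, optional) – Type of the requested output. Defaults to ‘Z’ (the embedding in the latent space).

  • device (str, optional) – Device used for prediction. Defaults to ‘cpu’.

Returns:

result – The requested result; by default, the embedding in the latent space (‘Z’).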

Return type:

torch.Tensor

score(X1, X2, labels, adata_sol=None, metric='clustering')

Score function to get score of prediction.

Parameters:
  • X1 (torch.Tensor) – Features of modality 1.

  • X2 (torch.Tensor) – Features of modality 2.

  • labels (torch.Tensor) – The ground truth labels for evaluation.

  • metric (str, optional) – Evaluation metric. Defaults to ‘clustering’.

Returns:

  • NMI_score (float) – NMI eval score.

  • ARI_score (float) – ARI eval score.

Modality matching

class dance.modules.multi_modality.match_modality.CMAE(hyperparameters)

CMAE class.

Parameters:

hyperparameters (dict) – A dictionary that contains arguments of CMAE. For details of the parameters, refer to the parser help document.

fit(train_mod1, train_mod2, aux_labels=None, checkpoint_directory='./checkpoint', val_ratio=0.15)

Train CMAE.

Parameters:
  • train_mod1 (torch.Tensor) – Features of input modality.

  • train_mod2 (torch.Tensor) – Features of target modality.

  • aux_labels (torch.Tensor, optional) – Auxiliary labels for extra supervision during training.

  • checkpoint_directory (str, optional) – Path to the checkpoint directory. Defaults to ‘./checkpoint’.

  • val_ratio (float, optional) – Ratio for automatic train-validation split. Defaults to 0.15.
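An automatic train-validation split controlled by val_ratio might look like the minimal sketch below, assuming a random permutation of cell indices and a validation size rounded down. train_val_split is a hypothetical helper, and CMAE's internal split may differ.

```python
import random
from typing import List, Tuple


def train_val_split(n_cells: int, val_ratio: float = 0.15,
                    seed: int = 0) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: shuffle cell indices and hold out a validation set.

    The validation size is floor(n_cells * val_ratio); this rounding rule is
    an assumption, not necessarily what CMAE does internally.
    """
    if not 0.0 <= val_ratio < 1.0:
        raise ValueError("val_ratio must be in [0, 1)")
    idx = list(range(n_cells))
    random.Random(seed).shuffle(idx)  # seeded for reproducibility
    n_val = int(n_cells * val_ratio)
    return idx[n_val:], idx[:n_val]


train_idx, val_idx = train_val_split(100, val_ratio=0.15)
```

With 100 cells and the default ratio, 85 indices go to training and 15 to validation, and the two sets never overlap.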

forward(mod1, mod2)

Forward function for torch.nn.Module.

Parameters:
  • mod1 (torch.Tensor) – Features of modality 1.

  • mod2 (torch.Tensor) – Features of modality 2.

Returns:

  • x_ab (torch.Tensor) – Prediction of target modality from input modality.

  • x_ba (torch.Tensor) – Prediction of input modality from target modality.

predict(mod1, mod2, metric='l1')

Predict function to get prediction of target modality features.

Parameters:
  • mod1 (torch.Tensor) – Features of modality 1.

  • mod2 (torch.Tensor) – Features of modality 2.

  • metric (str, optional) – Metric used for prediction. Defaults to ‘l1’.

Returns:

pred – Joint embedding of input modalities.

Return type:

torch.Tensor

resume(checkpoint_dir)

Resume function to resume from checkpoint file.

Parameters:

checkpoint_dir (str) – Path to the checkpoint directory.

Returns:

iterations – Current iteration number of the resumed model.

Return type:

int

save(snapshot_dir, iterations)

Save function to save parameters to checkpoint file.

Parameters:
  • snapshot_dir (str) – Path to the checkpoint directory.

  • iterations (int) – Current number of training iterations.

score(mod1, mod2, labels)

Score function to get score of prediction.

Parameters:
  • mod1 (torch.Tensor) – Features of modality 1.

  • mod2 (torch.Tensor) – Features of modality 2.

  • labels (torch.Tensor) – Ground truth mapping of modality 2.

Returns:

score – Matching accuracy.

Return type:

float
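The exact definition of matching accuracy used by score() is not stated here. One common reading, sketched below as an assumption, is top-1 accuracy: for each modality-1 cell, check whether its highest-scoring modality-2 candidate is its ground-truth partner. matching_accuracy is a hypothetical helper.

```python
from typing import Sequence


def matching_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Hypothetical top-1 matching accuracy.

    scores[i][j] is the predicted affinity between modality-1 cell i and
    modality-2 cell j; labels[i] is the index of cell i's true partner.
    """
    if len(scores) != len(labels) or not labels:
        raise ValueError("need one non-empty score row per label")
    hits = 0
    for row, target in zip(scores, labels):
        best = max(range(len(row)), key=row.__getitem__)  # top-1 candidate
        hits += int(best == target)
    return hits / len(labels)


print(matching_accuracy([[0.9, 0.1], [0.2, 0.8]], [0, 1]))  # 1.0
```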

class dance.modules.multi_modality.match_modality.MMVAE(subtask, params)

MMVAE class.

Parameters:
  • subtask (str) – Name of the subtask, composed of the names of the two modalities. This parameter determines some modality-specific features of the model.

  • params (argparse.Namespace) – A Namespace object that contains arguments of MMVAE. For details of the parameters, refer to the parser help document.

fit(x_train, y_train, val_ratio=0.15)

Fit function for training.

Parameters:
  • x_train (torch.Tensor) – Input modality for training.

  • y_train (torch.Tensor) – Target modality for training.

  • val_ratio (float, optional) – Ratio for automatic train-validation split. Defaults to 0.15.

forward(x)

Forward function for torch.nn.Module.

Parameters:

x (list[torch.Tensor]) – Features of two modalities.

Returns:

  • qz_xs (list[torch.Tensor]) – Approximate posteriors of the two modalities.

  • px_zs (list[torch.Tensor]) – Likelihoods of the two modalities.

  • zss (list[torch.Tensor]) – Latent samples of the two modalities.

predict(mod1, mod2, metric='minkowski')

Predict function to get the matching between two modalities.

Parameters:
  • mod1 (torch.Tensor) – Features of the first modality.

  • mod2 (torch.Tensor) – Features of the second modality.

  • metric (str, optional) – Metric of the matching function. Defaults to ‘minkowski’.

Returns:

pred – Predicted matching between two modalities.

Return type:

float
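With the default metric='minkowski', the distance between two embeddings u and v is (Σ|u_i − v_i|^p)^(1/p), which reduces to the Euclidean distance for p = 2. The sketch below matches each modality-1 embedding to its nearest modality-2 embedding under that distance. The helper names and the choice of p = 2 are assumptions, and MMVAE's predict() may post-process the distances differently.

```python
from typing import List, Sequence


def minkowski(u: Sequence[float], v: Sequence[float], p: float = 2.0) -> float:
    """Minkowski distance of order p (p = 2 gives the Euclidean distance)."""
    return sum(abs(a - b) ** p for a, b in zip(u, v)) ** (1.0 / p)


def nearest_match(emb1: Sequence[Sequence[float]], emb2: Sequence[Sequence[float]],
                  p: float = 2.0) -> List[int]:
    """Hypothetical helper: for each row of emb1, the index of the closest row of emb2."""
    return [min(range(len(emb2)), key=lambda j: minkowski(x, emb2[j], p)) for x in emb1]


print(nearest_match([[0, 0], [5, 5]], [[5, 4], [0, 1]]))  # [1, 0]
```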

score(mod1, mod2, labels=None, metric='minkowski')

Score function to get score of prediction.

Parameters:
  • mod1 (torch.Tensor) – Features of modality 1.

  • mod2 (torch.Tensor) – Features of modality 2.

  • labels (torch.Tensor, optional) – Labels of matching modality, i.e. cell correspondence between two modalities. Required when metric is not ‘loss’.

  • metric (str, optional) – Metric of the score function. Defaults to ‘minkowski’.

Returns:

score – Score of predicted matching, according to specified metric.

Return type:

float

class dance.modules.multi_modality.match_modality.ScMoGCNWrapper(args, layers, temp=1)

ScMoGCN class.

Parameters:
  • args (argparse.Namespace) – A Namespace object that contains arguments of ScMoGCN. For details of the parameters, refer to the parser help document.

  • layers (List[List[Union[int, float]]]) – Specification of hidden layers.

  • temp (int, optional) – Temperature for softmax. Defaults to 1.

fit(g_mod1, g_mod2, labels1, labels2, train_size)

Fit function for training.

Parameters:
  • g_mod1 (dgl.DGLGraph) – DGLGraph for modality 1.

  • g_mod2 (dgl.DGLGraph) – DGLGraph for modality 2.

  • labels1 (torch.Tensor) – Column-wise matching labels.

  • labels2 (torch.Tensor) – Row-wise matching labels.

  • train_size (int) – Number of training samples.

Returns:

None.

load(path, map_location=None)

Load model parameters from checkpoint file.

Parameters:
  • path (str) – Path to the checkpoint file.

  • map_location (str, optional) – Device to map the loaded parameters to. If not None, it is passed to torch.load as the map_location argument.

Returns:

None.

predict(idx, enhance=False, batch1=None, batch2=None, threshold_quantile=0.95)

Predict function to get the predicted matching matrix between two modalities.

Parameters:
  • idx (Iterable[int]) – Cell indices for prediction.

  • enhance (bool, optional) – Whether to enable enhanced matching (e.g. bipartite matching). Defaults to False.

  • batch1 (torch.Tensor, optional) – Batch labels of modality 1. Defaults to None.

  • batch2 (torch.Tensor, optional) – Batch labels of modality 2. Defaults to None.

  • threshold_quantile (float, optional) – Parameter of batch_separated_bipartite_matching that controls the sparsity when enhance is True. Defaults to 0.95.

Returns:

pred – Predicted matching matrix.

Return type:

torch.Tensor
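threshold_quantile is described only as controlling sparsity. As an illustrative assumption, quantile thresholding could keep only the score entries at or above the chosen quantile of all scores and zero out the rest, as sketched below. This is not batch_separated_bipartite_matching itself; quantile and sparsify are hypothetical helpers, and quantile uses linear interpolation (the same rule as NumPy's default).

```python
from typing import List, Sequence


def quantile(values: Sequence[float], q: float) -> float:
    """q-th quantile with linear interpolation (NumPy's default rule)."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def sparsify(scores: Sequence[Sequence[float]],
             threshold_quantile: float = 0.95) -> List[List[float]]:
    """Hypothetical helper: zero every score below the given quantile of all scores."""
    cut = quantile([v for row in scores for v in row], threshold_quantile)
    return [[v if v >= cut else 0.0 for v in row] for row in scores]


print(sparsify([[0.1, 0.9], [0.8, 0.2]], 0.95))  # [[0.0, 0.9], [0.0, 0.0]]
```

A higher quantile keeps fewer entries, so the resulting matching matrix becomes sparser.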

score(idx, labels1=None, labels2=None, labels_matrix=None, enhance=False, batch1=None, batch2=None, threshold_quantile=0.95)

Score function to get score of prediction.

Parameters:
  • idx (Iterable[int]) – Index of testing cells for scoring.

  • labels1 (torch.Tensor, optional) – Column-wise matching labels. Defaults to None.

  • labels2 (torch.Tensor, optional) – Row-wise matching labels. Defaults to None.

  • labels_matrix (torch.Tensor, optional) – Matching labels. Defaults to None.

  • enhance (bool, optional) – Whether to enable enhanced matching (e.g. bipartite matching). Defaults to False.

  • batch1 (torch.Tensor, optional) – Batch labels of modality 1. Defaults to None.

  • batch2 (torch.Tensor, optional) – Batch labels of modality 2. Defaults to None.

  • threshold_quantile (float, optional) – Parameter of batch_separated_bipartite_matching that controls the sparsity when enhance is True. Defaults to 0.95.

Returns:

score – Accuracy of predicted matching between two modalities.

Return type:

float

to(device)

Performs device conversion.

Parameters:

device (str) – Target device.

Returns:

self – Converted model.

Return type:

ScMoGCNWrapper

Modality prediction

class dance.modules.multi_modality.predict_modality.BabelWrapper(args, dim_in, dim_out)

Babel class.

Parameters:
  • args (argparse.Namespace) – A Namespace object that contains arguments of Babel. For details of the parameters, refer to the parser help document.

  • dim_in (int) – Input dimension.

  • dim_out (int) – Output dimension.

fit(x_train, y_train, max_epochs=500, val_ratio=0.15)

Fit function for training.

Parameters:
  • x_train (torch.Tensor) – Training input modality.

  • y_train (torch.Tensor) – Training output modality.

  • max_epochs (int, optional) – Maximum number of training epochs. Defaults to 500.

  • val_ratio (float, optional) – Validation ratio. Defaults to 0.15.

predict(test_mod1)

Predict function to get prediction of target modality features.

Parameters:

test_mod1 (torch.Tensor) – Input modality features.

Returns:

pred – Prediction of target modality features.

Return type:

torch.Tensor

score(test_mod1, test_mod2)

Score function to get score of prediction.

Parameters:
  • test_mod1 (torch.Tensor) – Input modality features.

  • test_mod2 (torch.Tensor) – Target modality features.

Returns:

score – RMSE loss of prediction.

Return type:

float
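The RMSE returned by score() compares predicted and observed target-modality features. A plain-Python sketch, assuming the squared error is averaged over every cell-by-feature entry (rmse is a hypothetical helper, not the Babel implementation):

```python
import math
from typing import Sequence


def rmse(pred: Sequence[Sequence[float]], target: Sequence[Sequence[float]]) -> float:
    """Hypothetical helper: RMSE over all cell-by-feature entries."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same number of cells")
    sq_err, n = 0.0, 0
    for p_row, t_row in zip(pred, target):
        if len(p_row) != len(t_row):
            raise ValueError("pred and target must have the same number of features")
        for p, t in zip(p_row, t_row):
            sq_err += (p - t) ** 2
            n += 1
    if n == 0:
        raise ValueError("no entries to compare")
    return math.sqrt(sq_err / n)


print(rmse([[1, 2], [3, 4]], [[1, 2], [3, 6]]))  # 1.0
```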

to(device)

Performs device conversion.

Parameters:

device (str) – Target device.

Returns:

self – Converted model.

Return type:

BabelWrapper

class dance.modules.multi_modality.predict_modality.CMAE(hyperparameters)

CMAE class.

Parameters:

hyperparameters (dict) – A dictionary that contains arguments of CMAE. For details of the parameters, refer to the parser help document.

fit(train_mod1, train_mod2, aux_labels=None, checkpoint_directory='./checkpoint', val_ratio=0.15)

Train CMAE.

Parameters:
  • train_mod1 (torch.Tensor) – Features of input modality.

  • train_mod2 (torch.Tensor) – Features of target modality.

  • aux_labels (torch.Tensor, optional) – Auxiliary labels for extra supervision during training.

  • checkpoint_directory (str, optional) – Path to the checkpoint directory. Defaults to ‘./checkpoint’.

  • val_ratio (float, optional) – Ratio for automatic train-validation split. Defaults to 0.15.

Returns:

None.

forward(mod1, mod2)

Forward function for torch.nn.Module.

Parameters:
  • mod1 (torch.Tensor) – Input modality features.

  • mod2 (torch.Tensor) – Target modality features.

Returns:

  • x_ab (torch.Tensor) – Prediction of target modality from input modality.

  • x_ba (torch.Tensor) – Prediction of input modality from target modality.

predict(mod1)

Predict function to get prediction of target modality features.

Parameters:

mod1 (torch.Tensor) – Input modality features.

Returns:

pred – Predicted features of target modality.

Return type:

torch.Tensor

resume(checkpoint_dir)

Resume function to resume from checkpoint file.

Parameters:

checkpoint_dir (str) – Path to the checkpoint directory.

Returns:

iterations – Current iteration number of the resumed model.

Return type:

int

save(checkpoint_dir, iterations)

Save function to save parameters to checkpoint file.

Parameters:
  • checkpoint_dir (str) – Path to the checkpoint directory.

  • iterations (int) – Current number of training iterations.

Returns:

None.

score(mod1, mod2)

Score function to get score of prediction.

Parameters:
  • mod1 (torch.Tensor) – Input modality features.

  • mod2 (torch.Tensor) – Target modality features.

Returns:

score – RMSE loss of predicted output modality features.

Return type:

float

class dance.modules.multi_modality.predict_modality.MMVAE(subtask, params)

MMVAE class.

Parameters:
  • subtask (str) – Name of the subtask, composed of the names of the two modalities. This parameter determines some modality-specific features of the model.

  • params (argparse.Namespace) – A Namespace object that contains arguments of MMVAE. For details of the parameters, refer to the parser help document.

fit(x_train, y_train, val_ratio=0.15)

Fit function for training.

Parameters:
  • x_train (torch.Tensor) – Input modality for training.

  • y_train (torch.Tensor) – Target modality for training.

  • val_ratio (float, optional) – Ratio for automatic train-validation split. Defaults to 0.15.

Returns:

None.

forward(x)

Forward function for torch.nn.Module.

Parameters:

x (list[torch.Tensor]) – Features of two modalities.

Returns:

  • qz_xs (list[torch.Tensor]) – Approximate posteriors of the two modalities.

  • px_zs (list[torch.Tensor]) – Likelihoods of the two modalities.

  • zss (list[torch.Tensor]) – Latent samples of the two modalities.

predict(X)

Predict function to get prediction of target modality features.

Parameters:

X (torch.Tensor) – Features of input modality and target modality.

Returns:

pred – Prediction of target modality from input modality.

Return type:

torch.Tensor

score(X, Y, metric='rmse')

Score function to get score of prediction.

Parameters:
  • X (torch.Tensor) – Features of input modality.

  • Y (torch.Tensor) – Features of target modality.

  • metric (str, optional) – Metric of the score function. Defaults to ‘rmse’.

Returns:

score – Score of the predicted target modality features, according to the specified metric.

Return type:

float

class dance.modules.multi_modality.predict_modality.ScMoGCNWrapper(args)

ScMoGCN class.

Parameters:

args (argparse.Namespace) – A Namespace object that contains arguments of ScMoGCN. For details of the parameters, refer to the parser help document.

fit(g, y, split=None, eval=True, verbose=2, y_test=None, logger=None, sampling=False, eval_interval=1)

Fit function for training.

Parameters:
  • g (dgl.DGLGraph) – Cell-feature graph constructed from the dataset.

  • y (torch.Tensor) – Labels of each training cell, a.k.a. target modality features.

  • split (dict, optional) – Cell indices for train-test split, required when eval is True.

  • eval (bool, optional) – Whether to evaluate during training. Defaults to True.

  • verbose (int, optional) – Verbosity level. Defaults to 2 (i.e. print and log).

  • y_test (torch.Tensor, optional) – Labels of each testing cell, required when eval is True.

  • logger (file object, optional) – Log file, required when verbose is 2.

  • sampling (bool, optional) – Whether to perform feature and cell sampling. Defaults to False.

  • eval_interval (int, optional) – Interval between evaluations during training. Defaults to 1.

Returns:

None.
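The eval and eval_interval parameters suggest a training loop that evaluates periodically. The sketch below shows one such schedule, assuming an evaluation after every eval_interval epochs. train_with_periodic_eval, train_step and evaluate are hypothetical names, not DANCE APIs.

```python
from typing import Callable, List, Optional, Tuple


def train_with_periodic_eval(n_epochs: int,
                             train_step: Callable[[int], object],
                             evaluate: Optional[Callable[[], float]] = None,
                             eval_interval: int = 1) -> List[Tuple[int, float]]:
    """Hypothetical loop: train every epoch, evaluate every eval_interval epochs.

    Passing evaluate=None plays the role of eval=False.
    """
    if eval_interval < 1:
        raise ValueError("eval_interval must be a positive integer")
    history: List[Tuple[int, float]] = []
    for epoch in range(n_epochs):
        train_step(epoch)
        if evaluate is not None and (epoch + 1) % eval_interval == 0:
            history.append((epoch, evaluate()))
    return history
```

With five epochs and eval_interval=2, evaluation runs after the second and fourth epochs (0-based epochs 1 and 3); larger intervals trade monitoring granularity for training speed.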

fit_with_sampling(g, y, split=None, eval=True, verbose=2, y_test=None, logger=None, eval_interval=1)

Fit function for training with graph sampling.

Parameters:
  • g (dgl.DGLGraph) – Cell-feature graph constructed from the dataset.

  • y (torch.Tensor) – Labels of each training cell, a.k.a. target modality features.

  • split (dict, optional) – Cell indices for train-test split, required when eval is True.

  • eval (bool, optional) – Whether to evaluate during training. Defaults to True.

  • verbose (int, optional) – Verbosity level. Defaults to 2 (i.e. print and log).

  • y_test (torch.Tensor, optional) – Labels of each testing cell, required when eval is True.

  • logger (file object, optional) – Log file, required when verbose is 2.

  • eval_interval (int, optional) – Interval between evaluations during training. Defaults to 1.

Returns:

None.

predict(graph, idx=None, device='cpu')

Predict function to get prediction of target modality features.

Parameters:
  • graph (dgl.DGLGraph) – Cell-feature graph constructed from the dataset.

  • idx (Iterable[int], optional) – Cell indices for prediction. Defaults to None, in which case all cells are predicted.

  • device (str, optional) – Device used for prediction. Defaults to ‘cpu’.

Returns:

pred – Predicted target modality features.

Return type:

torch.Tensor

score(g, idx, labels, device='cpu')

Score function to get score of prediction.

Parameters:
  • g (dgl.DGLGraph) – Cell-feature graph constructed from the dataset.

  • idx (Iterable[int]) – Index of testing cells for scoring.

  • labels (torch.Tensor) – Ground truth labels of cells, a.k.a. target modality features.

  • device (str, optional) – Device used for scoring. Defaults to ‘cpu’.

Returns:

loss – RMSE loss of predicted output modality features.

Return type:

float