
The models package hosts the suite of probabilistic models supported by Pyro-Velocity.


pyrovelocity.models.PyroVelocity(adata, input_type='raw', shared_time=True, model_type='auto', guide_type='auto', likelihood='Poisson', t_scale_on=False, plate_size=2, latent_factor='none', latent_factor_operation='selection', inducing_point_size=0, latent_factor_size=0, include_prior=False, use_gpu='auto', init=False, num_aux_cells=0, only_cell_times=True, decoder_on=False, add_offset=False, correct_library_size=True, cell_specific_kinetics=None, kinetics_num=None)

PyroVelocity is a class for constructing and training a Pyro model for probabilistic RNA velocity estimation. The model uses the probabilistic programming language Pyro to estimate the parameters of models of RNA transcription, splicing, and degradation dynamics, which can provide insight into cellular states and the transitions between them. It makes use of AnnData, scvi-tools, and other scverse ecosystem libraries.

Public methods support training the model with various configurations, generating posterior samples for downstream analysis, and saving and loading the model for reproducibility.

Attributes:
Name Type Description
use_gpu str Whether and which GPU to use.
cell_specific_kinetics Optional[str] Type of cell-specific kinetics.
k Optional[int] Number of kinetics.
layers List[str] List of layers in the dataset.
input_type str Type of input data.
module VelocityModule The Pyro module used for the velocity estimation model.
num_cells int Number of cells in the dataset.
num_samples int Number of posterior samples to generate.
_model_summary_string str Summary string for the model.
init_params_ Dict[str, Any] Initial parameters for the model.

For usage examples, including training the model and generating posterior samples, refer to the individual method docstrings.

Methods:
Name Description
__init__ PyroVelocity class for estimating RNA velocity and related tasks.
train Trains the PyroVelocity model using the provided data and configuration.
generate_posterior_samples Generates posterior samples for the given data using the trained PyroVelocity model.
compute_statistics_from_posterior_samples Estimate statistics from posterior samples and add them to the posterior_samples dictionary.
save_model Save the Pyro-Velocity model to a directory.
load_model Load the model from a directory with the same structure as that produced by the save method.


pyrovelocity.models.PyroVelocity.__init__(adata, input_type='raw', shared_time=True, model_type='auto', guide_type='auto', likelihood='Poisson', t_scale_on=False, plate_size=2, latent_factor='none', latent_factor_operation='selection', inducing_point_size=0, latent_factor_size=0, include_prior=False, use_gpu='auto', init=False, num_aux_cells=0, only_cell_times=True, decoder_on=False, add_offset=False, correct_library_size=True, cell_specific_kinetics=None, kinetics_num=None)

PyroVelocity class for estimating RNA velocity and related tasks.

Parameters:
Name Type Description Default
adata AnnData An AnnData object containing the gene expression data. required
input_type str Type of input data. Can be “raw”, “knn”, or “raw_cpm”. Defaults to “raw”. 'raw'
shared_time bool Whether to use shared time. Defaults to True. True
model_type str Type of model to use. Defaults to “auto”. 'auto'
guide_type str Type of guide to use. Defaults to “auto”. 'auto'
likelihood str Type of likelihood to use. Defaults to “Poisson”. 'Poisson'
t_scale_on bool Whether to use t_scale. Defaults to False. False
plate_size int Size of the plate. Defaults to 2. 2
latent_factor str Type of latent factor. Defaults to “none”. 'none'
latent_factor_operation str Operation to perform on the latent factor. Defaults to “selection”. 'selection'
inducing_point_size int Size of inducing points. Defaults to 0. 0
latent_factor_size int Size of latent factors. Defaults to 0. 0
include_prior bool Whether to include prior information. Defaults to False. False
use_gpu Union[str, bool, int] Whether and which GPU to use. Defaults to “auto”. Can be False. 'auto'
init bool Whether to initialize the model. Defaults to False. False
num_aux_cells int Number of auxiliary cells. Defaults to 0. 0
only_cell_times bool Whether to use only cell times. Defaults to True. True
decoder_on bool Whether to use decoder. Defaults to False. False
add_offset bool Whether to add offset. Defaults to False. False
correct_library_size Union[bool, str] Whether to correct library size or method to correct. Defaults to True. True
cell_specific_kinetics Optional[str] Type of cell-specific kinetics. Defaults to None. None
kinetics_num Optional[int] Number of kinetics. Defaults to None. None

Examples:
>>> # import necessary libraries
>>> import numpy as np
>>> import anndata
>>> from pyrovelocity.utils import pretty_log_dict, print_anndata, generate_sample_data
>>> from pyrovelocity.tasks.preprocess import copy_raw_counts
>>> from pyrovelocity.models._velocity import PyroVelocity
>>> # define fixtures
>>> try:
>>>     tmp = getfixture("tmp_path")
>>> except NameError:
>>>     import tempfile
>>>     tmp = tempfile.mkdtemp()  # persistent directory, so saved files remain inspectable
>>> doctest_model_path = str(tmp) + "/save_pyrovelocity_doctest_model"
>>> print(doctest_model_path)
>>> # setup sample data
>>> n_obs = 10
>>> n_vars = 5
>>> adata = generate_sample_data(n_obs=n_obs, n_vars=n_vars)
>>> copy_raw_counts(adata)
>>> print_anndata(adata)
>>> print(adata.X)
>>> print(adata.layers['spliced'])
>>> print(adata.layers['unspliced'])
>>> print(adata.obs['u_lib_size_raw'])
>>> print(adata.obs['s_lib_size_raw'])
>>> PyroVelocity.setup_anndata(adata)
>>> # train model with an explicit 80/20 train/validation split
>>> model = PyroVelocity(adata)
>>> model.train(max_epochs=5, train_size=0.8, valid_size=0.2, use_gpu="auto")
>>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30)
>>> print(posterior_samples.keys())
>>> assert isinstance(posterior_samples, dict), f"Expected a dictionary, got {type(posterior_samples)}"
>>> posterior_samples_log = pretty_log_dict(posterior_samples)
>>> model.save_model(doctest_model_path, overwrite=True)
>>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto")
>>> # train model with default parameters
>>> model = PyroVelocity(adata)
>>> model.train_faster(max_epochs=5, use_gpu="auto")
>>> model.save_model(doctest_model_path, overwrite=True)
>>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto")
>>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30)
>>> posterior_samples_log = pretty_log_dict(posterior_samples)
>>> print(posterior_samples.keys())
>>> # train model with specified batch size
>>> model = PyroVelocity(adata)
>>> model.train_faster_with_batch(batch_size=24, max_epochs=5, use_gpu="auto")
>>> model.save_model(doctest_model_path, overwrite=True)
>>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto")
>>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30)
>>> posterior_samples_log = pretty_log_dict(posterior_samples)
>>> print(posterior_samples.keys())
>>> # If running from an interactive session, the temporary directory
>>> # can be inspected to review the saved model files. When run as a
>>> # doctest it is automatically cleaned up after the test completes.
>>> print(f"Output located in {doctest_model_path}")
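The example prints the u_lib_size_raw and s_lib_size_raw columns after calling copy_raw_counts. Assuming these columns hold per-cell totals of the raw unspliced and spliced counts (an assumption about copy_raw_counts, not a documented guarantee), they can be sketched with NumPy as follows. library_sizes is a hypothetical helper, not a pyrovelocity function.

```python
import numpy as np


def library_sizes(unspliced, spliced):
    """Hypothetical helper: per-cell total counts of two dense (cells x genes) matrices.

    Assumption: mirrors what u_lib_size_raw / s_lib_size_raw contain. Sparse
    layers would need .toarray() (or a sparse-aware sum) first.
    """
    u_lib = np.asarray(unspliced).sum(axis=1)  # one total per cell (row)
    s_lib = np.asarray(spliced).sum(axis=1)
    return u_lib, s_lib


# Two cells, three genes (synthetic counts).
unspliced = np.array([[1, 0, 2], [3, 1, 0]])
spliced = np.array([[4, 2, 0], [0, 5, 5]])
u_lib, s_lib = library_sizes(unspliced, spliced)
```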


pyrovelocity.models.PyroVelocity.train(**kwargs)

Trains the PyroVelocity model using the provided data and configuration.

The method leverages the Pyro library to train the model using the underlying data. It relies on the VelocityTrainingMixin to define the training logic.

Parameters:
**kwargs : dict, optional
    Additional keyword arguments to be passed to the underlying train method
    provided by the `VelocityTrainingMixin`.


pyrovelocity.models.PyroVelocity.generate_posterior_samples(adata=None, indices=None, batch_size=None, num_samples=100)

Generates posterior samples for the given data using the trained PyroVelocity model.

The method generates posterior samples by running the trained model on the provided data and returns a dictionary containing samples for each parameter.

Parameters:
Name Type Description Default
adata AnnData AnnData object containing the data for which posterior samples are to be computed. If not provided, the AnnData object used to initialize the model is used. None
indices Sequence[int] Indices of cells in adata for which posterior samples are to be computed. None
batch_size int Size of the mini-batches used during computation. If not provided, the entire dataset is processed at once. None
num_samples int Number of posterior samples to compute for each parameter. Defaults to 100. 100

Returns:
Type Description
Dict[str, ndarray] A dictionary containing the posterior samples for each parameter.
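Downstream analyses usually reduce these samples to point estimates and credible intervals. The following is a minimal NumPy sketch, assuming each array in the dictionary stores the posterior draws along its first axis (num_samples first). summarize_posterior and the synthetic arrays are hypothetical and only illustrate the shape of the result; they are not part of pyrovelocity.

```python
import numpy as np


def summarize_posterior(samples, q=(0.05, 0.95)):
    """Hypothetical helper: posterior mean and equal-tailed interval per parameter.

    Assumes every array stores posterior draws along axis 0.
    """
    summary = {}
    for name, draws in samples.items():
        draws = np.asarray(draws, dtype=float)
        # np.quantile with two probabilities returns a stacked (2, ...) array.
        lower, upper = np.quantile(draws, q, axis=0)
        summary[name] = {"mean": draws.mean(axis=0), "lower": lower, "upper": upper}
    return summary


rng = np.random.default_rng(0)
# Synthetic draws: 30 samples, 10 cells, 5 genes (shapes chosen for illustration only).
fake_samples = {
    "alpha": rng.gamma(2.0, 1.0, size=(30, 1, 5)),
    "ut": rng.poisson(3.0, size=(30, 10, 5)).astype(float),
}
summary = summarize_posterior(fake_samples)
```

Each summary entry drops the leading sample axis, so summary["ut"]["mean"] has one value per cell and gene.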


pyrovelocity.models.PyroVelocity.compute_statistics_from_posterior_samples(adata, posterior_samples, vector_field_basis='umap', ncpus_use=1)

Estimate statistics from posterior samples and add them to the posterior_samples dictionary. The names of the statistics incorporated into the dictionary are:

  • gene_ranking
  • original_spaces_embeds_magnitude
  • genes
  • vector_field_posterior_samples
  • vector_field_posterior_mean
  • fdri
  • embeds_magnitude
  • embeds_angle
  • ut_mean
  • st_mean
  • pca_vector_field_posterior_samples
  • pca_embeds_angle
  • pca_fdri

The following data are removed from the posterior_samples dictionary:

  • u
  • s
  • ut
  • st

Each of these sets requires further documentation.
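The exact computations behind vector_field_posterior_mean, embeds_magnitude, and embeds_angle are not described here. As a rough illustration only, the sketch below derives a posterior mean vector, per-draw magnitudes, and per-draw angles from vector-field samples in a 2D embedding, assuming an array of shape (num_samples, n_cells, 2). The function name and shapes are hypothetical and do not reproduce pyrovelocity's implementation.

```python
import numpy as np


def summarize_vector_field(samples):
    """Hypothetical helper: summarize 2D embedding velocity samples.

    samples: array of shape (num_samples, n_cells, 2), one 2D vector per cell per draw.
    """
    samples = np.asarray(samples, dtype=float)
    mean_field = samples.mean(axis=0)  # (n_cells, 2) posterior mean vector per cell
    magnitude = np.hypot(samples[..., 0], samples[..., 1])  # (num_samples, n_cells)
    # Angle of each vector in degrees, in the range [-180, 180].
    angle = np.degrees(np.arctan2(samples[..., 1], samples[..., 0]))
    return mean_field, magnitude, angle


# Two draws for two cells (synthetic values).
draws = np.array([
    [[3.0, 4.0], [0.0, -2.0]],
    [[1.0, 0.0], [0.0, -4.0]],
])
mean_field, magnitude, angle = summarize_vector_field(draws)
```

Keeping per-draw magnitudes and angles, rather than only those of the mean vector, preserves the posterior spread, which is what makes uncertainty summaries possible.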

Parameters:
Name Type Description Default
adata AnnData AnnData object containing the data for which posterior samples were computed. required
posterior_samples Dict[str, ndarray] Dictionary containing the posterior samples for each parameter. required
vector_field_basis str Basis for the vector field. Defaults to “umap”. 'umap'
ncpus_use int Number of CPUs to use for computation. Defaults to 1. 1

Returns:
Type Description
Dict[str, ndarray] Dictionary containing the posterior samples with the added statistics.


pyrovelocity.models.PyroVelocity.save_model(dir_path, prefix=None, overwrite=True, save_anndata=False, **anndata_write_kwargs)

Save the Pyro-Velocity model to a directory.

Dispatches to the save method inherited from BaseModelClass, which saves the model state dictionary, variable names, and user attributes.

Parameters:
Name Type Description Default
dir_path str Path to the directory where the model will be saved. required
prefix Optional[str] Prefix to add to the saved files. Defaults to None. None
overwrite bool Whether to overwrite existing files. Defaults to True. True
save_anndata bool Whether to save the AnnData object. Defaults to False. False
**anndata_write_kwargs dict Keyword arguments passed to the AnnData write method when save_anndata is True. {}


pyrovelocity.models.PyroVelocity.load_model(dir_path, adata=None, use_gpu='auto', prefix=None, backup_url=None)

Load the model from a directory with the same structure as that produced by the save method.

Parameters:
Name Type Description Default
dir_path str Path to the directory where the model is saved. required
adata Optional[AnnData] AnnData object to load into the model. Defaults to None. None
use_gpu str Whether and which GPU to use. Defaults to “auto”. 'auto'
prefix Optional[str] Prefix of the saved files. Defaults to None. None
backup_url Optional[str] URL from which to download the saved model if it is not present on disk. Defaults to None. None

Raises:
Type Description
RuntimeError If the model is not an instance of PyroBaseModuleClass.

Returns:
Type Description
BaseModelClass The loaded PyroVelocity model.


pyrovelocity.models.mrna_dynamics(tau, u0, s0, alpha, beta, gamma)

Computes the mRNA dynamics given a temporal coordinate, parameter values, and initial conditions.

The expression st_gamma_equals_beta, used for the case where the gamma parameter equals the beta parameter, is taken from Equation 2.12 of

Li T, Shi J, Wu Y, Zhou P. On the mathematics of RNA velocity I: Theoretical analysis. CSIAM Transactions on Applied Mathematics. 2021;2: 1–55. doi:10.4208/
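For reference, assuming the standard RNA velocity equations du/dτ = α − βu and ds/dτ = βu − γs with u(0) = u0 and s(0) = s0, the closed-form values at time τ are given below. The γ = β line is the limit of the general expression as γ → β and should correspond to the Equation 2.12 case; it is derived here rather than copied from the function body, so treat it as a check rather than the implementation. With the inputs of the example further down, the γ ≠ β expressions give u ≈ 1.1377 and s ≈ 0.9269.

```latex
\begin{aligned}
u(\tau) &= u_0 e^{-\beta\tau} + \frac{\alpha}{\beta}\left(1 - e^{-\beta\tau}\right), \\
s(\tau) &= s_0 e^{-\gamma\tau} + \frac{\alpha}{\gamma}\left(1 - e^{-\gamma\tau}\right)
  + \frac{\alpha - \beta u_0}{\gamma - \beta}\left(e^{-\gamma\tau} - e^{-\beta\tau}\right),
  && \gamma \neq \beta, \\
s(\tau) &= s_0 e^{-\beta\tau} + \frac{\alpha}{\beta}\left(1 - e^{-\beta\tau}\right)
  + \left(\beta u_0 - \alpha\right)\tau e^{-\beta\tau},
  && \gamma = \beta.
\end{aligned}
```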

Parameters:
Name Type Description Default
tau Tensor Time points. required
u0 Tensor Initial value of u. required
s0 Tensor Initial value of s. required
alpha Tensor Alpha parameter (transcription rate). required
beta Tensor Beta parameter (splicing rate). required
gamma Tensor Gamma parameter (degradation rate). required

Returns:
Type Description
Tuple[Tensor, Tensor] Tuple containing the values of u and s at time tau.

Examples:
>>> import torch
>>> from pyrovelocity.models import mrna_dynamics
>>> tau = torch.tensor(2.0)
>>> u0 = torch.tensor(1.0)
>>> s0 = torch.tensor(0.5)
>>> alpha = torch.tensor(0.5)
>>> beta = torch.tensor(0.4)
>>> gamma = torch.tensor(0.3)
>>> mrna_dynamics(tau, u0, s0, alpha, beta, gamma)
(tensor(1.1377), tensor(0.9269))
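A plain-Python sketch of the same closed-form solutions, assuming scalar float inputs instead of tensors. mrna_dynamics_sketch is a hypothetical name, not part of pyrovelocity; it only makes the two branches (γ ≠ β and γ = β) explicit. With the inputs used in the example above it reproduces (1.1377, 0.9269).

```python
import math


def mrna_dynamics_sketch(tau, u0, s0, alpha, beta, gamma):
    """Hypothetical scalar version of the closed-form u(tau), s(tau).

    Solves du/dt = alpha - beta * u and ds/dt = beta * u - gamma * s
    with u(0) = u0 and s(0) = s0. Not pyrovelocity's implementation.
    """
    exp_b = math.exp(-beta * tau)
    exp_g = math.exp(-gamma * tau)
    ut = u0 * exp_b + alpha / beta * (1.0 - exp_b)
    if math.isclose(gamma, beta):
        # gamma == beta: limit of the general expression (the Equation 2.12 case).
        st = s0 * exp_b + alpha / beta * (1.0 - exp_b) + (beta * u0 - alpha) * tau * exp_b
    else:
        st = (
            s0 * exp_g
            + alpha / gamma * (1.0 - exp_g)
            + (alpha - beta * u0) / (gamma - beta) * (exp_g - exp_b)
        )
    return ut, st


ut, st = mrna_dynamics_sketch(2.0, 1.0, 0.5, 0.5, 0.4, 0.3)
print(round(ut, 4), round(st, 4))  # 1.1377 0.9269
```

Selecting the γ = β branch explicitly avoids dividing by a vanishing γ − β; a tensor implementation would need an elementwise selection between the two branches instead of a Python if.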