Models
The models package hosts the suite of probabilistic models supported by Pyro-Velocity.
PyroVelocity
pyrovelocity.models.PyroVelocity(adata, input_type='raw', shared_time=True, model_type='auto', guide_type='auto', likelihood='Poisson', t_scale_on=False, plate_size=2, latent_factor='none', latent_factor_operation='selection', inducing_point_size=0, latent_factor_size=0, include_prior=False, use_gpu='auto', init=False, num_aux_cells=0, only_cell_times=True, decoder_on=False, add_offset=False, correct_library_size=True, cell_specific_kinetics=None, kinetics_num=None)
PyroVelocity is a class for constructing and training a Pyro model for probabilistic RNA velocity estimation. It uses the probabilistic programming language Pyro to estimate the parameters of models of RNA transcription, splicing, and degradation dynamics, which can offer insight into cellular states and the transitions between them. It builds on AnnData, scvi-tools, and other scverse ecosystem libraries.
Public methods cover training the model with various configurations, generating posterior samples for downstream analysis, and saving/loading the model for reproducibility.
Attributes
Name | Type | Description |
---|---|---|
use_gpu | str | Whether and which GPU to use. |
cell_specific_kinetics | Optional[str] | Type of cell-specific kinetics. |
k | Optional[int] | Number of kinetics. |
layers | List[str] | List of layers in the dataset. |
input_type | str | Type of input data. |
module | VelocityModule | The Pyro module used for the velocity estimation model. |
num_cells | int | Number of cells in the dataset. |
num_samples | int | Number of posterior samples to generate. |
_model_summary_string | str | Summary string for the model. |
init_params_ | Dict[str, Any] | Initial parameters for the model. |
For usage examples, including training the model and generating posterior samples, refer to the individual method docstrings.
Methods
Name | Description |
---|---|
init | PyroVelocity class for estimating RNA velocity and related tasks. |
train | Trains the PyroVelocity model using the provided data and configuration. |
generate_posterior_samples | Generates posterior samples for the given data using the trained PyroVelocity model. |
compute_statistics_from_posterior_samples | Estimate statistics from posterior samples and add them to the posterior_samples dictionary. |
save_model | Save the Pyro-Velocity model to a directory. |
load_model | Load the model from a directory with the same structure as that produced by the save method. |
init
pyrovelocity.models.PyroVelocity.__init__(adata, input_type='raw', shared_time=True, model_type='auto', guide_type='auto', likelihood='Poisson', t_scale_on=False, plate_size=2, latent_factor='none', latent_factor_operation='selection', inducing_point_size=0, latent_factor_size=0, include_prior=False, use_gpu='auto', init=False, num_aux_cells=0, only_cell_times=True, decoder_on=False, add_offset=False, correct_library_size=True, cell_specific_kinetics=None, kinetics_num=None)
PyroVelocity class for estimating RNA velocity and related tasks.
Parameters
Name | Type | Description | Default |
---|---|---|---|
adata | AnnData | An AnnData object containing the gene expression data. | required |
input_type | str | Type of input data. Can be “raw”, “knn”, or “raw_cpm”. Defaults to “raw”. | 'raw' |
shared_time | bool | Whether to use shared time. Defaults to True. | True |
model_type | str | Type of model to use. Defaults to “auto”. | 'auto' |
guide_type | str | Type of guide to use. Defaults to “auto”. | 'auto' |
likelihood | str | Type of likelihood to use. Defaults to “Poisson”. | 'Poisson' |
t_scale_on | bool | Whether to use t_scale. Defaults to False. | False |
plate_size | int | Size of the plate. Defaults to 2. | 2 |
latent_factor | str | Type of latent factor. Defaults to “none”. | 'none' |
latent_factor_operation | str | Operation to perform on the latent factor. Defaults to “selection”. | 'selection' |
inducing_point_size | int | Size of inducing points. Defaults to 0. | 0 |
latent_factor_size | int | Size of latent factors. Defaults to 0. | 0 |
include_prior | bool | Whether to include prior information. Defaults to False. | False |
use_gpu | Union[str, bool, int] | Whether and which GPU to use. Defaults to “auto”. Can be False. | 'auto' |
init | bool | Whether to initialize the model. Defaults to False. | False |
num_aux_cells | int | Number of auxiliary cells. Defaults to 0. | 0 |
only_cell_times | bool | Whether to use only cell times. Defaults to True. | True |
decoder_on | bool | Whether to use a decoder. Defaults to False. | False |
add_offset | bool | Whether to add an offset. Defaults to False. | False |
correct_library_size | Union[bool, str] | Whether to correct for library size, or the correction method to use. Defaults to True. | True |
cell_specific_kinetics | Optional[str] | Type of cell-specific kinetics. Defaults to None. | None |
kinetics_num | Optional[int] | Number of kinetics. Defaults to None. | None |
Examples
>>> # import necessary libraries
>>> import numpy as np
>>> import anndata
>>> from pyrovelocity.utils import pretty_log_dict, print_anndata, generate_sample_data
>>> from pyrovelocity.tasks.preprocess import copy_raw_counts
>>> from pyrovelocity.models._velocity import PyroVelocity
>>> # define fixtures
>>> try:
...     tmp = getfixture("tmp_path")
... except NameError:
...     import tempfile
...     tmp = tempfile.mkdtemp()
>>> doctest_model_path = str(tmp) + "/save_pyrovelocity_doctest_model"
>>> print(doctest_model_path)
>>> # setup sample data
>>> n_obs = 10
>>> n_vars = 5
>>> adata = generate_sample_data(n_obs=n_obs, n_vars=n_vars)
>>> copy_raw_counts(adata)
>>> print_anndata(adata)
>>> print(adata.X)
>>> print(adata.layers['spliced'])
>>> print(adata.layers['unspliced'])
>>> print(adata.obs['u_lib_size_raw'])
>>> print(adata.obs['s_lib_size_raw'])
>>> PyroVelocity.setup_anndata(adata)
>>> # train model with macroscopic validation set
>>> model = PyroVelocity(adata)
>>> model.train(max_epochs=5, train_size=0.8, valid_size=0.2, use_gpu="auto")
>>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30)
>>> print(posterior_samples.keys())
>>> assert isinstance(posterior_samples, dict), f"Expected a dictionary, got {type(posterior_samples)}"
>>> posterior_samples_log = pretty_log_dict(posterior_samples)
>>> model.save_model(doctest_model_path, overwrite=True)
>>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto")
>>> # train model with default parameters
>>> model = PyroVelocity(adata)
>>> model.train_faster(max_epochs=5, use_gpu="auto")
>>> model.save_model(doctest_model_path, overwrite=True)
>>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto")
>>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30)
>>> posterior_samples_log = pretty_log_dict(posterior_samples)
>>> print(posterior_samples.keys())
>>> # train model with specified batch size
>>> model = PyroVelocity(adata)
>>> model.train_faster_with_batch(batch_size=24, max_epochs=5, use_gpu="auto")
>>> model.save_model(doctest_model_path, overwrite=True)
>>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto")
>>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30)
>>> posterior_samples_log = pretty_log_dict(posterior_samples)
>>> print(posterior_samples.keys())
>>> # If running from an interactive session, the temporary directory
>>> # can be inspected to review the saved model files. When run as a
>>> # doctest it is automatically cleaned up after the test completes.
>>> print(f"Output located in {doctest_model_path}")
train
pyrovelocity.models.PyroVelocity.train(**kwargs)
Trains the PyroVelocity model using the provided data and configuration.
The method trains the model with Pyro on the data the model was initialized with. The training logic itself is defined by the VelocityTrainingMixin.
Parameters
Name | Type | Description | Default |
---|---|---|---|
**kwargs | dict | Additional keyword arguments passed to the underlying train method provided by the VelocityTrainingMixin. | {} |
generate_posterior_samples
pyrovelocity.models.PyroVelocity.generate_posterior_samples(adata=None, indices=None, batch_size=None, num_samples=100)
Generates posterior samples for the given data using the trained PyroVelocity model.
The method generates posterior samples by running the trained model on the provided data and returns a dictionary containing samples for each parameter.
Parameters
Name | Type | Description | Default |
---|---|---|---|
adata | AnnData | AnnData object containing the data for which posterior samples are to be computed. If not provided, the AnnData object used to initialize the model is used. | None |
indices | Sequence[int] | Indices of cells in adata for which posterior samples are to be computed. | None |
batch_size | int | Size of the mini-batches used during computation. If not provided, the entire dataset is used. | None |
num_samples | int | Number of posterior samples to compute for each parameter. | 100 |
Returns
Type | Description |
---|---|
Dict[str, ndarray] | A dictionary containing the posterior samples for each parameter. |
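Because the return value is a plain dictionary of arrays, it can be summarized directly with NumPy. The sketch below is a minimal example under two assumptions: that axis 0 of each array indexes the num_samples posterior draws, and that the synthetic "gamma" key merely stands in for a real parameter name. The helper summarize_posterior is hypothetical and not part of pyrovelocity.

```python
import numpy as np


def summarize_posterior(posterior_samples, lower=0.025, upper=0.975):
    """Posterior mean and equal-tailed interval for each array in the dict.

    Assumes (hypothetically) that axis 0 of every array indexes the
    posterior draws, i.e. has length num_samples.
    """
    summary = {}
    for name, draws in posterior_samples.items():
        draws = np.asarray(draws)
        summary[name] = {
            "mean": draws.mean(axis=0),
            "lower": np.quantile(draws, lower, axis=0),
            "upper": np.quantile(draws, upper, axis=0),
        }
    return summary


# Synthetic stand-in for generate_posterior_samples output:
# 30 draws of a per-gene parameter for 5 genes.
rng = np.random.default_rng(0)
fake_samples = {"gamma": rng.gamma(shape=2.0, scale=0.5, size=(30, 5))}
stats = summarize_posterior(fake_samples)
print(stats["gamma"]["mean"].shape)  # (5,)
```

Reducing over axis 0 keeps any per-cell or per-gene structure in the remaining axes, so the same helper works regardless of each parameter's shape.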
compute_statistics_from_posterior_samples
pyrovelocity.models.PyroVelocity.compute_statistics_from_posterior_samples(adata, posterior_samples, vector_field_basis='umap', ncpus_use=1)
Estimate statistics from posterior samples and add them to the posterior_samples dictionary. The names of the statistics incorporated into the dictionary are:
gene_ranking
original_spaces_embeds_magnitude
genes
vector_field_posterior_samples
vector_field_posterior_mean
fdri
embeds_magnitude
embeds_angle
ut_mean
st_mean
pca_vector_field_posterior_samples
pca_embeds_angle
pca_fdri
The following entries are removed from the posterior_samples dictionary:
u
s
ut
st
Each of these sets requires further documentation.
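The exact definitions of embeds_magnitude and embeds_angle are not given here. Purely as an illustration, the sketch below shows one common way to reduce posterior draws of a 2-D embedding vector field (for example, in the "umap" basis) to per-draw magnitudes and angles. The (num_samples, num_cells, 2) array layout, the degree convention measured from the +x axis, and the helper name magnitude_and_angle are assumptions, not pyrovelocity's implementation.

```python
import numpy as np


def magnitude_and_angle(vector_field_samples):
    """Per-draw magnitude and angle (degrees) of 2-D embedding vectors.

    vector_field_samples: array of shape (num_samples, num_cells, 2),
    a hypothetical layout assumed for illustration.
    """
    vx = vector_field_samples[..., 0]
    vy = vector_field_samples[..., 1]
    magnitude = np.hypot(vx, vy)              # Euclidean length of each vector
    angle = np.degrees(np.arctan2(vy, vx))    # angle from +x axis, in (-180, 180]
    return magnitude, angle


samples = np.array([[[3.0, 4.0], [0.0, 1.0]]])  # 1 draw, 2 cells
mag, ang = magnitude_and_angle(samples)
print(mag)
print(ang)
```

Computing these per draw, rather than on the posterior mean field, preserves the spread of magnitudes and angles across draws for later uncertainty summaries.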
Parameters
Name | Type | Description | Default |
---|---|---|---|
adata | AnnData | AnnData object containing the data for which posterior samples were computed. | required |
posterior_samples | Dict[str, ndarray] | Dictionary containing the posterior samples for each parameter. | required |
vector_field_basis | str | Basis for the vector field. Defaults to “umap”. | 'umap' |
ncpus_use | int | Number of CPUs to use for computation. Defaults to 1. | 1 |
Returns
Type | Description |
---|---|
Dict[str, ndarray] | Dictionary containing the posterior samples with added statistics. |
save_model
pyrovelocity.models.PyroVelocity.save_model(dir_path, prefix=None, overwrite=True, save_anndata=False, **anndata_write_kwargs)
Save the Pyro-Velocity model to a directory.
Dispatches to the save method of the inherited BaseModelClass, which calls torch.save on a model state dictionary, variable names, and user attributes.
Parameters
Name | Type | Description | Default |
---|---|---|---|
dir_path | str | Path to the directory where the model will be saved. | required |
prefix | Optional[str] | Prefix to add to the saved files. Defaults to None. | None |
overwrite | bool | Whether to overwrite existing files. Defaults to True. | True |
save_anndata | bool | Whether to save the AnnData object. Defaults to False. | False |
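The save contract described above (one target directory, an optional filename prefix, and an overwrite flag) can be imitated with the standard library to make its semantics concrete. This is a hypothetical sketch: save_state, load_state, and the model.pkl filename are illustrative names, pickle stands in for torch.save, and refusing an existing directory unless overwrite=True is a simplifying assumption about the behavior.

```python
import os
import pickle
import tempfile


def save_state(dir_path, state, prefix=None, overwrite=True):
    """Hypothetical helper mirroring the dir_path/prefix/overwrite contract."""
    if os.path.exists(dir_path) and not overwrite:
        raise ValueError(f"{dir_path} already exists; pass overwrite=True.")
    os.makedirs(dir_path, exist_ok=True)
    # The prefix is prepended to the file name, so several models can share a directory.
    path = os.path.join(dir_path, f"{prefix or ''}model.pkl")
    with open(path, "wb") as fh:
        pickle.dump(state, fh)
    return path


def load_state(dir_path, prefix=None):
    """Load the state written by save_state with the same prefix."""
    path = os.path.join(dir_path, f"{prefix or ''}model.pkl")
    with open(path, "rb") as fh:
        return pickle.load(fh)


with tempfile.TemporaryDirectory() as tmp:
    target = os.path.join(tmp, "model_dir")
    save_state(target, {"var_names": ["g1", "g2"]}, prefix="run1_")
    restored = load_state(target, prefix="run1_")
    print(restored["var_names"])  # ['g1', 'g2']
```

Loading must use the same prefix as saving, which is why load_model also accepts a prefix argument.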
load_model
pyrovelocity.models.PyroVelocity.load_model(dir_path, adata=None, use_gpu='auto', prefix=None, backup_url=None)
Load the model from a directory with the same structure as that produced by the save method.
Parameters
Name | Type | Description | Default |
---|---|---|---|
dir_path | str | Path to the directory where the model is saved. | required |
adata | Optional[AnnData] | AnnData object to load into the model. Defaults to None. | None |
use_gpu | str | Whether and which GPU to use. Defaults to “auto”. | 'auto' |
prefix | Optional[str] | Prefix to add to the saved files. Defaults to None. | None |
backup_url | Optional[str] | URL to download the model from. Defaults to None. | None |
Raises
Type | Description |
---|---|
RuntimeError | If the model is not an instance of PyroBaseModuleClass. |
Returns
Type | Description |
---|---|
BaseModelClass | The loaded PyroVelocity model. |
mrna_dynamics
pyrovelocity.models.mrna_dynamics(tau, u0, s0, alpha, beta, gamma)
Computes the mRNA dynamics given temporal coordinate, parameter values, and initial conditions.
The expression st_gamma_equals_beta, used for the case where the gamma parameter equals the beta parameter, is taken from Equation 2.12 of
Li T, Shi J, Wu Y, Zhou P. On the mathematics of RNA velocity I: Theoretical analysis. CSIAM Transactions on Applied Mathematics. 2021;2: 1–55. doi:10.4208/csiam-am.so-2020-0001
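For reference, the standard constant-rate transcription, splicing, and degradation ODE and its closed-form solution are written out below. This is a derivation of the textbook solution, not a transcription of Equation 2.12; the gamma = beta branch is the limit of the general expression as gamma approaches beta. Substituting the example values used for mrna_dynamics (tau = 2, u0 = 1, s0 = 0.5, alpha = 0.5, beta = 0.4, gamma = 0.3) gives u ≈ 1.1377 and s ≈ 0.9269, which agrees with its doctest output and suggests the function evaluates these expressions.

```latex
\begin{aligned}
\frac{du}{d\tau} &= \alpha - \beta u, \qquad \frac{ds}{d\tau} = \beta u - \gamma s, \\
u(\tau) &= u_0 e^{-\beta\tau} + \frac{\alpha}{\beta}\left(1 - e^{-\beta\tau}\right), \\
s(\tau) &= s_0 e^{-\gamma\tau} + \frac{\alpha}{\gamma}\left(1 - e^{-\gamma\tau}\right)
  + \frac{\alpha - \beta u_0}{\gamma - \beta}\left(e^{-\gamma\tau} - e^{-\beta\tau}\right),
  \quad \gamma \neq \beta, \\
s(\tau) &= s_0 e^{-\beta\tau} + \frac{\alpha}{\beta}\left(1 - e^{-\beta\tau}\right)
  + \left(\beta u_0 - \alpha\right)\tau e^{-\beta\tau},
  \quad \gamma = \beta.
\end{aligned}
```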
Parameters
Name | Type | Description | Default |
---|---|---|---|
tau | Tensor | Time points. | required |
u0 | Tensor | Initial value of u. | required |
s0 | Tensor | Initial value of s. | required |
alpha | Tensor | Alpha parameter. | required |
beta | Tensor | Beta parameter. | required |
gamma | Tensor | Gamma parameter. | required |
Returns
Type | Description |
---|---|
Tuple[Tensor, Tensor] | Tuple containing the values of u and s at time tau. |
Examples
>>> import torch
>>> from pyrovelocity.models import mrna_dynamics
>>> tau = torch.tensor(2.0)
>>> u0 = torch.tensor(1.0)
>>> s0 = torch.tensor(0.5)
>>> alpha = torch.tensor(0.5)
>>> beta = torch.tensor(0.4)
>>> gamma = torch.tensor(0.3)
>>> mrna_dynamics(tau, u0, s0, alpha, beta, gamma)
(tensor(1.1377), tensor(0.9269))
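To check the doctest values without torch, the standard closed-form solution of du/dtau = alpha - beta*u and ds/dtau = beta*u - gamma*s can be evaluated with plain floats. This is a minimal sketch: mrna_dynamics_scalar is a hypothetical helper, not part of pyrovelocity, and the tol threshold that switches to the gamma = beta branch is an illustrative choice. It reproduces the example output of mrna_dynamics for the same inputs.

```python
import math


def mrna_dynamics_scalar(tau, u0, s0, alpha, beta, gamma, tol=1e-12):
    """Scalar closed-form u(tau), s(tau) for constant alpha, beta, gamma."""
    exp_beta = math.exp(-beta * tau)
    exp_gamma = math.exp(-gamma * tau)
    # Unspliced: relaxes from u0 toward the steady state alpha / beta.
    u = u0 * exp_beta + alpha / beta * (1.0 - exp_beta)
    if abs(gamma - beta) < tol:
        # gamma == beta: limit of the general formula as gamma -> beta.
        s = (s0 * exp_beta + alpha / beta * (1.0 - exp_beta)
             + (beta * u0 - alpha) * tau * exp_beta)
    else:
        s = (s0 * exp_gamma + alpha / gamma * (1.0 - exp_gamma)
             + (alpha - beta * u0) / (gamma - beta) * (exp_gamma - exp_beta))
    return u, s


u, s = mrna_dynamics_scalar(2.0, 1.0, 0.5, 0.5, 0.4, 0.3)
print(round(u, 4), round(s, 4))  # 1.1377 0.9269
```

The separate gamma = beta branch avoids dividing by gamma - beta, which is why the library documents a dedicated expression for that case.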