Probabilistic inference in dynamical systems

Author

Pyrovelocity Team

Abstract

This notebook discusses and demonstrates one way to perform probabilistic inference in dynamical systems. In particular, we prepare and simulate a deterministic dynamical system and then perform probabilistic inference on the latent variables of the system. This illustrates how to combine these two approaches, which is essential for understanding Pyro-Velocity.

Keywords

single-cell transcriptomics, probabilistic modeling, dynamical systems, RNA velocity

1 Model

1.1 Description

Here we describe the model we will use to demonstrate probabilistic inference in dynamical systems. We begin with a candidate effective theory [1, 2] of gene transcription and splicing following [3-5] for the model and [6] for the presentation of the analysis. In particular we show in Table 1 the variables and parameters of the model together with their units and rough order of magnitude estimates for their values [7].

Table 1: Variables and parameters of the transcription-splicing-degradation model of [3] with order of magnitude estimates for ranges based on [7] or references therein.
| Symbol | Description | Units | \(O(-)\) Estimate | Note |
|---|---|---|---|---|
| \(u\) | Number of pre-mRNA | molecules/cell | \(10^0 - 10^4\) | Wide range accounts for low to high gene expression levels. |
| \(s\) | Number of mRNA | molecules/cell | \(10^0 - 10^5\) | Similar to \(u\); depends on gene expression and stability of mRNA. |
| \(t\) | Time | seconds (\(s\)) to hours (\(hr\)) | \(5s\) - \(48hr\) | Depends on the duration of the experimental observation. |
| \(\alpha\) | Production rate of pre-mRNA | molecules/(cell·hr) | \(10^0 - 10^3\) | Many transcripts are produced at rates in the range of 1 to 1000 per hour. |
| \(\beta\) | Splicing rate of pre-mRNA | \(hr^{-1}\) | \(10^{-1} - 10^2\) | Many pre-mRNA to mRNA splicing times are in the range of 1 minute to 6 hours. |
| \(\gamma\) | Degradation rate of mRNA | \(hr^{-1}\) | \(10^{-2} - 10^0\) | Many mRNA half-lives are in the range of a half-hour to a day. |

Given state variables representing concentrations of pre-mRNA, \(u\), and mRNA, \(s\), we have the following ordinary differential equations taken from [3] (please refer again to Table 1 for the meaning of the variables and parameters),

\[\begin{align} \frac{du}{dt} & = \alpha - \beta u \label{eq-dudt}, \\ \frac{ds}{dt} & = \beta u - \gamma s \label{eq-dsdt}, \end{align}\]

which proposes a mean-field model for the dynamics of transcription and splicing on a continuous state space. It is important to note this is only one among an eventual ensemble of models that, with appropriate analysis and inference, may be organized into a hierarchy of such models according to their domains of validity relative to the scales involved in and resolution of observations of the relevant phenomenon [1]. Of course, in the context of single gene transcription without explicit modeling of interactions, this ensemble of models would include the various adaptations and extensions of the so-called random telegraph model and the experimental justification for its consideration [8-10]. As an example of a minimal extension, \(\eqref{eq-dudt}\) and \(\eqref{eq-dsdt}\) are usually presented with the concept that the parameter \(\alpha\) could better account for the external inputs to the regulation of transcription of the gene if it were allowed to be a function of time, \(\alpha(t)\). However, because this complicates or eliminates analytical tractability, and thus maximum likelihood inference, this function may be approximated as piecewise constant [5]. We will return to this point later, but, for the purpose of simply illustrating the relationships among analyzing, simulating, and performing inference upon the parameters of a dynamical system, we will assume that \(\alpha\) and the other parameters are constant.
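As a quick sanity check on the constant-parameter model, the following minimal sketch integrates \(\eqref{eq-dudt}\) and \(\eqref{eq-dsdt}\) with a forward Euler scheme and compares the long-time values with the steady states obtained by setting both derivatives to zero, \(u = \alpha/\beta\) and \(s = \alpha/\gamma\). The parameter values are illustrative assumptions picked from within the ranges of Table 1, and the helper names `tsd_rhs` and `euler_integrate` are hypothetical rather than part of any library.

```python
def tsd_rhs(u, s, alpha, beta, gamma):
    """Right-hand side of du/dt = alpha - beta*u and ds/dt = beta*u - gamma*s."""
    return alpha - beta * u, beta * u - gamma * s


def euler_integrate(u0, s0, alpha, beta, gamma, t_end, dt=1e-3):
    """Integrate from (u0, s0) to t_end with forward Euler; return the final (u, s)."""
    u, s = u0, s0
    n_steps = int(round(t_end / dt))
    for _ in range(n_steps):
        du, ds = tsd_rhs(u, s, alpha, beta, gamma)
        u, s = u + dt * du, s + dt * ds
    return u, s


# Illustrative values (assumptions): molecules/(cell*hr), 1/hr, 1/hr.
alpha, beta, gamma = 100.0, 10.0, 0.5
u_end, s_end = euler_integrate(0.0, 0.0, alpha, beta, gamma, t_end=48.0)

print(u_end, alpha / beta)   # pre-mRNA approaches alpha / beta
print(s_end, alpha / gamma)  # mRNA approaches alpha / gamma
```

Forward Euler is used here only because it is the shortest scheme to write down; any standard ODE solver would serve the same purpose.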

1.2 Dimensional analysis

It is common to perform a dimensional analysis of dynamical models to ensure the units of the variables and parameters are consistent and reduce the total number of parameters by the number of dimensions of the associated dilation group symmetry [11]. Intuitively, this maps sets of parameters to associated equivalence classes of similar dynamics. For reference, in addition to dimensional analysis and nondimensionalization, the procedure we make use of depends on what is frequently referred to as the Buckingham \(\Pi\) theorem [6].
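The counting argument behind the Buckingham \(\Pi\) theorem can be sketched in a few lines: the number of independent dimensionless groups equals the number of quantities minus the rank of the matrix of their dimensional exponents. The sketch below assumes that concentration (molecules/cell) and time are the two independent dimensions for the quantities of Table 1; the exponent bookkeeping and the helper `matrix_rank` are illustrative assumptions.

```python
from fractions import Fraction


def matrix_rank(rows):
    """Rank of a small matrix via Gaussian elimination over exact fractions."""
    m = [[Fraction(x) for x in row] for row in rows]
    rank = 0
    for col in range(len(m[0])):
        pivot = next((r for r in range(rank, len(m)) if m[r][col] != 0), None)
        if pivot is None:
            continue
        m[rank], m[pivot] = m[pivot], m[rank]
        for r in range(len(m)):
            if r != rank and m[r][col] != 0:
                factor = m[r][col] / m[rank][col]
                m[r] = [a - factor * b for a, b in zip(m[r], m[rank])]
        rank += 1
    return rank


# Exponents of (concentration, time) for each quantity (assumed bookkeeping).
dimensions = {
    "u": (1, 0),
    "s": (1, 0),
    "t": (0, 1),
    "alpha": (1, -1),
    "beta": (0, -1),
    "gamma": (0, -1),
}
rows = list(dimensions.values())
n_quantities = len(rows)
n_dimensions = matrix_rank(rows)
n_dimensionless = n_quantities - n_dimensions
print(n_quantities, n_dimensions, n_dimensionless)
```

Under these assumptions the count is six quantities, two independent dimensions, and therefore four dimensionless variables and parameters.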

Table 1 together with (\(\ref{eq-dudt}\)-\(\ref{eq-dsdt}\)) contain essentially all the information we require to get started. The dimensions of the variables and parameters together with dimensionless ones are provided in Table 2.

Table 2: Variables and parameters of the transcription-splicing-degradation model together with their fundamental units and their relations to dimensionless parameters.
| Dimensioned Parameter | Relation to Rescaling Parameters | Fundamental Units |
|---|---|---|
| Production Rate (\(\alpha\)) | \(\alpha = U_0 \beta\) | molecules/(cell·time) |
| Splicing Rate (\(\beta\)) | Reference Scale for \(t^*\) | \(1/\text{time}\) |
| Degradation Rate (\(\gamma\)) | \(\gamma = \gamma^* \beta\) | \(1/\text{time}\) |
| Pre-mRNA Concentration (\(u\)) | \(u = u^* U_0\) | molecules/cell |
| mRNA Concentration (\(s\)) | \(s = s^* U_0\) | molecules/cell |
| Time (\(t\)) | \(t = t^* / \beta\) | time |

Note that molecules are dimensionless numbers while cells have units associated to their volume (i.e. the cube of a distance or length \(L^3\)). Usually we would just write volume but we retain “cell” to associate to the object that determines the relevant volume in this case. The dimensionless parameters are defined in Table 3.

Table 3: Dimensionless variables and parameters, their definitions, and intuitive descriptions.
| Dimensionless Variables and Parameters | Definition | Description |
|---|---|---|
| \(u^{\ast}\), \(s^{\ast}\) | \(u / U_0\), \(s / U_0\) | (Pre-)mRNA concentration relative to the characteristic scale \(U_0 = \alpha/\beta\), which is set by the balance between production and splicing rates. |
| \(t^*\) | \(\beta t\) | Time relative to the characteristic time scale set by the splicing rate. |
| \(\gamma^*\) | \(\gamma / \beta\) | Relative degradation rate, comparing the degradation rate of mRNA to the splicing rate, indicating the stability or turnover rate of mRNA relative to its production/splicing. |
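The idea that rescaling maps different parameter sets onto the same dimensionless dynamics can be illustrated with the pre-mRNA equation alone. The sketch below uses the standard closed-form solution of the linear equation \(du/dt = \alpha - \beta u\), namely \(u(t) = U_0 + (u_0 - U_0) e^{-\beta t}\) with \(U_0 = \alpha/\beta\), and two hypothetical parameter sets that share the same dimensionless initial condition. The function names and numerical values are assumptions chosen for illustration.

```python
import math


def u_dimensional(t, u0, alpha, beta):
    """Closed-form solution of du/dt = alpha - beta*u with u(0) = u0."""
    U0 = alpha / beta
    return U0 + (u0 - U0) * math.exp(-beta * t)


def to_dimensionless(u, t, alpha, beta):
    """Apply u* = u / U0 and t* = beta * t with U0 = alpha / beta."""
    return u / (alpha / beta), beta * t


# Two hypothetical parameter sets with very different scales.
param_sets = [
    {"alpha": 100.0, "beta": 10.0},
    {"alpha": 3.0, "beta": 0.5},
]
u_star_0 = 0.1  # shared dimensionless initial condition
t_star = 1.5    # shared dimensionless time

results = []
for p in param_sets:
    U0 = p["alpha"] / p["beta"]
    t = t_star / p["beta"]  # real time corresponding to t*
    u = u_dimensional(t, u_star_0 * U0, p["alpha"], p["beta"])
    results.append(to_dimensionless(u, t, p["alpha"], p["beta"])[0])

expected = 1.0 + (u_star_0 - 1.0) * math.exp(-t_star)
print(results, expected)
```

Both parameter sets give the same value of \(u^{\ast}\) at the same \(t^{\ast}\), which is the sense in which they belong to one equivalence class of dynamics.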

Given our previously defined parameter range for the splicing rate \(\beta\), we can interpret the dimensionless parameter \(t^{\ast}\) according to Table 4.

Table 4: Interpretation of the dimensionless time \(t^{\ast}\) across various splicing rates \(\beta\) (hr\(^{-1}\)), demonstrating how real time and expected splicing events scale with changes in \(\beta\). Each unit of \(t^{\ast}\) represents a measure of time normalized by the splicing rate. The final three columns give the expected number of splicing events per molecule of pre-mRNA over 1, 10, and 100 units of \(t^{\ast}\), which is the meaning of a time unit under this scaling rule. These columns confirm that the dimensionless time directly reflects the expected splicing activity, placing otherwise difficult to interpret numbers in a universally applicable framework based on scales intrinsic to a given instance of the model.
| \(\beta\) (hr\(^{-1}\)) | 1 unit of \(t^*\) (hr) | 10 units of \(t^*\) (hr) | 100 units of \(t^*\) (hr) | Expected splicing events in 1 unit of \(t^{\ast}\) | Expected splicing events in 10 units of \(t^{\ast}\) | Expected splicing events in 100 units of \(t^{\ast}\) |
|---|---|---|---|---|---|---|
| \(10^{-1}\) | 10 | 100 | 1000 | 1 | 10 | 100 |
| \(10^0\) | 1 | 10 | 100 | 1 | 10 | 100 |
| \(10^1\) | 0.1 | 1 | 10 | 1 | 10 | 100 |
| \(10^2\) | 0.01 | 0.1 | 1 | 1 | 10 | 100 |
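The entries of Table 4 follow directly from \(t = t^{\ast}/\beta\) and from the expected number of splicing events per pre-mRNA molecule, \(\beta t = t^{\ast}\). A minimal sketch that recomputes them, with variable names chosen here for illustration:

```python
import math

betas = [0.1, 1.0, 10.0, 100.0]  # splicing rates in 1/hr, as in Table 4
t_star_values = [1, 10, 100]     # dimensionless durations

table = {}
for beta in betas:
    hours = [t_star / beta for t_star in t_star_values]  # real time t = t*/beta
    events = [beta * t for t in hours]                   # expected events = beta*t
    table[beta] = (hours, events)
    print(f"beta = {beta:g} 1/hr -> hours: {hours}, events: {events}")

# The expected number of events equals t* for every beta (up to rounding).
all_match = all(
    math.isclose(e, ts)
    for _, events in table.values()
    for e, ts in zip(events, t_star_values)
)
print(all_match)
```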

Combining variables and parameters gives six in total. We have two fundamental dimensions (time and length) that we can eliminate to arrive at four essential variables and parameters. To derive these, we begin by declaring the equations in Python using SymPy and then substitute the dimensionless variables and parameters to obtain the dimensionless equations.

Code
from sympy import (
    diff,
    Eq,
    Function,
    init_printing,
    solve,
    symbols,
)
from IPython.display import display

init_printing(use_latex=True)

alpha, beta, gamma, gamma_star = symbols("alpha beta gamma gamma_star")
U0, T0 = symbols("U0 T0")
t, t_star = symbols(
    "t t_star",
)

u = Function("u")(t)
s = Function("s")(t)
u_star = Function("u_star")(t_star)
s_star = Function("s_star")(t_star)

du_dt = Eq(diff(u, t), alpha - beta * u)
ds_dt = Eq(diff(s, t), beta * u - gamma * s)
display(du_dt)
display(ds_dt)

\(\displaystyle \frac{d}{d t} u{\left(t \right)} = \alpha - \beta u{\left(t \right)}\)

\(\displaystyle \frac{d}{d t} s{\left(t \right)} = \beta u{\left(t \right)} - \gamma s{\left(t \right)}\)

and define dimensionless variables and parameters as

U0_eq = Eq(U0, alpha / beta)
T0_eq = Eq(T0, 1 / beta)

u_dimless = Eq(u_star, u / U0)
s_dimless = Eq(s_star, s / U0)
t_dimless = Eq(t_star, beta * t)

Now we change the variables in the equations to the dimensionless ones

Code
du_dt_dimless = Eq(
    beta * diff(u_star * alpha / beta, t_star),
    du_dt.rhs.subs(
        {
            u: u_star * alpha / beta,
            s: s_star * alpha / beta,
            t: t_star / beta,
            gamma: gamma_star * beta,
        }
    ),
).simplify()
du_dt_dimless

\(\displaystyle \alpha \left(1 - u_{star}{\left(t_{star} \right)}\right) = \alpha \frac{d}{d t_{star}} u_{star}{\left(t_{star} \right)}\)

and

Code
ds_dt_dimless = Eq(
    beta * diff(s_star * alpha / beta, t_star),
    ds_dt.rhs.subs(
        {
            u: u_star * alpha / beta,
            s: s_star * alpha / beta,
            t: t_star / beta,
            gamma: gamma_star * beta,
        }
    ),
).simplify()

ds_dt_dimless

\(\displaystyle \alpha \frac{d}{d t_{star}} s_{star}{\left(t_{star} \right)} = \alpha \left(- \gamma_{star} s_{star}{\left(t_{star} \right)} + u_{star}{\left(t_{star} \right)}\right)\)

Dividing both sides of each equation by \(\alpha\), we see that we correctly arrive at the equivalent dimensionless system of equations

\[\begin{align} \frac{du^{\ast}}{dt^{\ast}} & = 1 - u^{\ast} \label{eq-dustardtstar}, \\ \frac{ds^{\ast}}{dt^{\ast}} & = u^{\ast} - \gamma^{\ast} s^{\ast}, \label{eq-dsstardtstar} \end{align}\]

given in \(\eqref{eq-dustardtstar}\) and \(\eqref{eq-dsstardtstar}\), which contains the \(6 - 2 = 4\) variables and parameters: \(u^{\ast}, s^{\ast}, t^{\ast}, \gamma^{\ast}\).

We’ll solve the dimensionless system analytically, focusing on what can be understood about the system's dynamics. Given specific initial conditions \(u^{\ast}(0) = u^{\ast}_0\) and \(s^{\ast}(0) = s^{\ast}_0\), which we restore to arbitrary \(t^{\ast}_0\) at the end of the derivation, the first equation

\[ \frac{du^{\ast}}{dt^{\ast}} = 1 - u^{\ast}, \]

can be solved using separation of variables

\[ u^{\ast}(t^{\ast}) = C_1 e^{-t^{\ast}} + 1. \]

\(C_1\) is of course the constant determined by the initial conditions, which can be shown to be \(C_1 = u^{\ast}_0 - 1,\) giving us

\[ u^{\ast}(t^{\ast}) = (u^{\ast}_0 - 1) e^{-t^{\ast}} + 1. \]

This equation describes how the concentration of pre-mRNA, in units of the characteristic concentration scale, \(U_0 = \alpha/\beta\), evolves over the timescale characterized by the inverse of the splicing rate, \(1/\beta\). The term \(e^{-t^{\ast}}\) indicates that any deviation of \(u^{\ast}\) from its steady state \(1\) decays exponentially over time.

Now, we substitute the expression for \(u^{\ast}(t^{\ast})\) into the differential equation for mRNA concentration

\[ \frac{ds^{\ast}}{dt^{\ast}} = u^{\ast} - \gamma^{\ast} s^{\ast} \]

which gives

\[ \frac{ds^{\ast}}{dt^{\ast}} = [(u^{\ast}_0 - 1) e^{-t^{\ast}} + 1] - \gamma^{\ast} s^{\ast} \]

This equation is a non-homogeneous first-order linear ODE for \(s^{\ast}\). To solve it, we can use an integrating factor \(e^{\gamma^{\ast} t^{\ast}}\), which gives

\[ e^{\gamma^{\ast} t^{\ast}} \frac{ds^{\ast}}{dt^{\ast}} + \gamma^{\ast} e^{\gamma^{\ast} t^{\ast}} s^{\ast} = (u^{\ast}_0 - 1) e^{(\gamma^{\ast}-1) t^{\ast}} + e^{\gamma^{\ast} t^{\ast}} \]

Integrating both sides with respect to \(t^{\ast}\) yields \(s^{\ast}(t^{\ast})\)

\[ s^{\ast}(t^{\ast}) = C_2 e^{-\gamma^{\ast} t^{\ast}} + \frac{u^{\ast}_0 - 1}{\gamma^{\ast} - 1} e^{-t^{\ast}} + \frac{1}{\gamma^{\ast}} \]

where \(C_2\) is another integration constant determined by initial conditions. Incorporating the initial condition \(s^{\ast}(0) = s^{\ast}_0\) into the solution for \(s^{\ast}(t^{\ast})\) determines \(C_2\)

\[ C_2 = s^{\ast}_0 - \frac{u^{\ast}_0 - 1}{\gamma^{\ast} - 1} - \frac{1}{\gamma^{\ast}}. \]

With \(C_2\) determined, we define

\[ \xi = \frac{u^{\ast}_0 - 1}{\gamma^{\ast} - 1}. \]

The complete solutions for the dimensionless system for \(\gamma^{\ast} \neq 1\) are then given by

\[\begin{align} u^{\ast}(t^{\ast}) & = \underbrace{1}_{\text{Steady State}} + \underbrace{(u^{\ast}_0 - 1) e^{-t^{\ast}}}_{\text{Transient Component}}, \label{eq-ustar-dimless} \\ s^{\ast}(t^{\ast}) & = \underbrace{\frac{1}{\gamma^{\ast}}}_{\text{Steady State}} + \underbrace{\left( s^{\ast}_0 - \xi - \frac{1}{\gamma^{\ast}} \right) e^{-\gamma^{\ast} t^{\ast}} + \xi e^{-t^{\ast}}}_{\text{Transient Components}}. \label{eq-sstar-dimless} \end{align}\]
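The closed-form solutions for \(\gamma^{\ast} \neq 1\) can be checked symbolically by substituting them back into the dimensionless equations and the initial conditions. The following is a minimal sketch using SymPy; the symbol names are chosen here for illustration, and all four printed residuals should simplify to zero.

```python
import sympy as sp

t, g, u0, s0 = sp.symbols("t_star gamma_star u0 s0")
xi = (u0 - 1) / (g - 1)

# Candidate solutions for gamma_star != 1.
u_sol = 1 + (u0 - 1) * sp.exp(-t)
s_sol = 1 / g + (s0 - xi - 1 / g) * sp.exp(-g * t) + xi * sp.exp(-t)

# Residuals of du*/dt* = 1 - u* and ds*/dt* = u* - gamma* s*.
res_u = sp.simplify(sp.diff(u_sol, t) - (1 - u_sol))
res_s = sp.simplify(sp.diff(s_sol, t) - (u_sol - g * s_sol))

# Mismatch with the initial conditions at t* = 0.
ic_u = sp.simplify(u_sol.subs(t, 0) - u0)
ic_s = sp.simplify(s_sol.subs(t, 0) - s0)

print(res_u, res_s, ic_u, ic_s)
```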

For the case \(\gamma^{\ast} = 1\), we can solve \(\eqref{eq-dustardtstar}\) and \(\eqref{eq-dsstardtstar}\) to find the solution for \(u^{\ast}\) is unchanged and \(s^{\ast}\) is then given by

\[ s^{\ast}(t^{\ast}) = 1 + (s^{\ast}_0 - 1) e^{-t^{\ast}} + (u^{\ast}_0 - 1) t^{\ast} e^{-t^{\ast}}. \]
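One way to see that the \(\gamma^{\ast} = 1\) expression is the limit of the general solution is to evaluate the general formula at \(\gamma^{\ast} = 1 + \epsilon\) for shrinking \(\epsilon\). The sketch below does this numerically; the initial conditions and evaluation time are arbitrary illustrative choices.

```python
import math


def s_star_general(t, u0, s0, gamma):
    """Closed-form s*(t*) valid for gamma* != 1."""
    xi = (u0 - 1.0) / (gamma - 1.0)
    return (
        1.0 / gamma
        + (s0 - xi - 1.0 / gamma) * math.exp(-gamma * t)
        + xi * math.exp(-t)
    )


def s_star_gamma_one(t, u0, s0):
    """Closed-form s*(t*) for the special case gamma* = 1."""
    return 1.0 + (s0 - 1.0) * math.exp(-t) + (u0 - 1.0) * t * math.exp(-t)


u0, s0, t = 0.1, 0.1, 0.7
exact = s_star_gamma_one(t, u0, s0)
gaps = [
    abs(s_star_general(t, u0, s0, 1.0 + eps) - exact)
    for eps in (1e-2, 1e-3, 1e-4)
]
print(gaps)  # the gap shrinks as gamma* approaches 1
```

In floating point the general formula loses accuracy very close to \(\gamma^{\ast} = 1\) because \(\xi\) becomes large and the exponential terms nearly cancel, which is a practical reason to keep the separate branch.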

Note that we can restore an arbitrary initial timepoint by substituting \(t^{\ast} \rightarrow t^{\ast} - t^{\ast}_0\).

The dimensionless pre-mRNA dynamics \(u^{\ast}(t^{\ast})\) given in \(\eqref{eq-ustar-dimless}\) settles to its steady state value of 1 at rate 1, regardless of the initial concentration, reflecting the balance between transcription and splicing rates. Note again that this does not refer to an arbitrary value of \(1\) but rather represents the equivalence class of ratios \(\alpha/\beta\).

The dimensionless mRNA dynamics \(s^{\ast}(t^{\ast})\) given in \(\eqref{eq-sstar-dimless}\) reveals a slightly more complicated balance of forces. The first transient term shows an exponential decay influenced by \(\gamma^{\ast}\), the degradation rate relative to the splicing rate, when \(\gamma^{\ast}\) is not precisely \(1\). The second transient term reflects how the initial deviation in pre-mRNA concentration from its steady state affects mRNA levels over time. The steady state of \(s^{\ast}\), \(1/\gamma^{\ast}\), simply reflects the balance between splicing, which is the only source of mRNA production in this model, and degradation. Just as in the case of the pre-mRNA, it is expressed in units of the characteristic concentration \(U_0 = \alpha/\beta\), the transcription rate relative to the splicing rate.

The analytical solutions to the dimensionless system reveal how the pre-mRNA and mRNA concentrations evolve over time towards their steady states. The exponential terms highlight the system's inherent timescales: \(u^{\ast}\) relaxes to its steady state independently, while \(s^{\ast}\)’s dynamics are coupled to both \(u^{\ast}\) and its own degradation rate. While this system is extremely simplified relative to well-known biology, as a stylized model it still indicates the interplay among transcription, splicing, and degradation rates.

While it is trivial to solve this system analytically, in more general cases this will not be possible. So we will proceed to simulate the system to characterize the possible solution sets and then perform probabilistic inference on the latent variables.

2 Data

If we are given samples of pre-mRNA and mRNA counts along with the sampling time of each cell, we can describe the data set \(\mathcal{D}\) consisting of pre-mRNA and mRNA counts for a number of genes \(G\) across a number of cells \(N\)

\[\begin{align} \mathcal{D} = \left\{ (t_j, u_{ij}, s_{ij}) \mid i \in \{1, \ldots, G\}, j \in \{1, \ldots, N\} \right\} \label{eq-dataset}, \end{align}\]

where

  • \(t_j\) represents the sampling time for cell \(j\),
  • \(u_{ij}\) represents the count of pre-mRNA for gene \(i\) in cell \(j\),
  • \(s_{ij}\) represents the count of mRNA for the same gene \(i\) in the same cell \(j\),
  • \(G\) is the total number of genes in the study,
  • \(N\) is the total number of cells sampled.

Roughly speaking, the inference problem involves

  1. Normalization: Transforming the observed counts \(u_{ij}\) and \(s_{ij}\) to their dimensionless counterparts \(u^{\ast}_{ij}\) and \(s^{\ast}_{ij}\) using a reasonable concentration scale \(U_0\).
  2. Model Fitting: Applying statistical methods to estimate the latent variables \(\Theta = \left( \gamma^{\ast}, u^{\ast}_0, s^{\ast}_0 \right)\) that best fit the observed data.
  3. Evaluation: Assessing the fit of the model and the estimated parameters’ biological plausibility and consistency across different genes and cells.
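The normalization step can be sketched on a toy data set as follows. The counts, sampling times, and per-gene scales below are hypothetical values invented for illustration, and the choice \(U_0 = \alpha/\beta\) with assumed rates is only one possible "reasonable concentration scale"; in practice the scale would have to be chosen or estimated.

```python
# Hypothetical toy data set: 2 genes x 3 cells, one sampling time per cell.
times_hr = [0.5, 1.0, 2.0]            # t_j in hours
u_counts = [[4, 7, 9], [20, 35, 48]]  # u_ij, pre-mRNA counts
s_counts = [[1, 5, 12], [3, 30, 90]]  # s_ij, mRNA counts

# Assumed per-gene rates used only to set the scales (not estimated here).
alpha = [20.0, 100.0]                 # molecules/(cell*hr)
beta = [2.0, 2.0]                     # 1/hr
U0 = [a / b for a, b in zip(alpha, beta)]


def normalize_counts(counts, times, U0, beta):
    """Rescale counts[i][j] by U0[i] and times[j] by beta[i] for each gene i."""
    scaled = [[c / U0[i] for c in row] for i, row in enumerate(counts)]
    t_star = [[beta[i] * t for t in times] for i in range(len(counts))]
    return scaled, t_star


u_star, t_star = normalize_counts(u_counts, times_hr, U0, beta)
s_star, _ = normalize_counts(s_counts, times_hr, U0, beta)
print(u_star)
print(s_star)
print(t_star)
```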

We will return later to the alternative version of this problem where \(t_j\) are not observed and must be inferred along with the latent variables \(\Theta = \left( \gamma^{\ast}, u^{\ast}_0, s^{\ast}_0, t^{\ast} \right)\).

This representation of the dataset and the associated objective of statistical learning provides a solid foundation for applying probabilistic modeling techniques to calibrate an inference procedure and evaluate the plausibility of this model in such a manner that we can eventually compare multiple candidate models to one another.

3 Simulation

We will primarily focus on simulating the system, since this will generalize to more complicated models. However, we will confirm our simulations for this first and simplest model recapitulate its analytical solution derived above.

As a simple example, we will simulate the system for a single set of initial conditions and parameters.

Code
import diffrax
from pyrovelocity.models import solve_transcription_splicing_model
from pyrovelocity.models import solve_transcription_splicing_model_analytical
from pyrovelocity.logging import configure_logging
from jax import numpy as jnp

logger = configure_logging("nbs")

ts0 = jnp.linspace(0.1, 4.0, 40)
ts1 = jnp.linspace(4.0 + (10.0 - 4.0) / 20, 10.0, 20)
ts = jnp.concatenate([ts0, ts1])
initial_state = jnp.array([0.1, 0.1])
params = (1.00,)
colormap_name = "RdBu"

solution = solve_transcription_splicing_model(
    ts,
    initial_state,
    params,
)
analytical_solution = solve_transcription_splicing_model_analytical(
    ts,
    initial_state,
    params,
)
analytical_simulation_error = diffrax.Solution(
    t0=ts[0],
    t1=ts[-1],
    ts=ts,
    ys=analytical_solution.ys - solution.ys,
    interpolation=None,
    stats={},
    result=diffrax.RESULTS.successful,
    solver_state=None,
    controller_state=None,
    made_jump=None,
)

logger.info(
    f"\nTrajectory preview times:\n{solution.ts[:3]}\n"
    f"\nSimulated trajectory values:\n{solution.ys[:5]}\n"
    f"\nAnalytical trajectory values:\n{analytical_solution.ys[:5]}\n\n"
    f"\nTrajectory preview times:\n{solution.ts[-3:]}\n"
    f"\nSimulated trajectory values:\n{solution.ys[-3:]}\n"
    f"\nAnalytical trajectory values:\n{analytical_solution.ys[-3:]}\n\n"
)
[21:26:06] INFO     nbs                                                                                            
                    Trajectory preview times:                                                                      
                    [0.1 0.2 0.3]                                                                                  
                                                                                                                   
                    Simulated trajectory values:                                                                   
                    [[0.1        0.1       ]                                                                       
                     [0.18564637 0.10421096]                                                                       
                     [0.26314226 0.11577081]                                                                       
                     [0.33326364 0.13324267]                                                                       
                     [0.3967121  0.15539673]]                                                                      
                                                                                                                   
                    Analytical trajectory values:                                                                  
                    [[0.10000002 0.10000002]                                                                       
                     [0.18564636 0.10421099]                                                                       
                     [0.2631424  0.11577088]                                                                       
                     [0.33326364 0.13324271]                                                                       
                     [0.396712   0.1553968 ]]                                                                      
                                                                                                                   
                                                                                                                   
                    Trajectory preview times:                                                                      
                    [ 9.400001  9.7      10.      ]                                                                
                                                                                                                   
                    Simulated trajectory values:                                                                   
                    [[0.99991643 0.99915135]                                                                       
                     [0.9999378  0.99935234]                                                                       
                     [0.999954   0.9995065 ]]                                                                      
                                                                                                                   
                    Analytical trajectory values:                                                                  
                    [[0.99991775 0.99915254]                                                                       
                     [0.999939   0.9993538 ]                                                                       
                     [0.9999548  0.9995078 ]]                                                                      
                                                                                                                   
                                                                                                                   

We can visualize the results of the simulation by plotting the pre-mRNA and mRNA concentrations over time as shown in Figure 1,

Code
from pyrovelocity.plots import plot_deterministic_simulation_trajectories

plot_deterministic_simulation_trajectories(
    solution=solution,
    title_prefix="TSD Model Simulated",
    colormap_name=colormap_name,
)
Figure 1: Simulated trajectories of pre-mRNA and mRNA concentrations over time for the transcription-splicing-degradation model.

Plotting the difference between the analytical solution and the simulation, as shown in Figure 2, confirms that the two agree to within the expected error of the numerical simulation (\(O(10^{-5} - 10^{-8})\)).

Code
from pyrovelocity.plots import plot_deterministic_simulation_trajectories

plot_deterministic_simulation_trajectories(
    solution=analytical_simulation_error,
    title_prefix="TSD Model Analytical-Simulation Error",
    colormap_name=colormap_name,
)
Figure 2: Absolute error in analytical vs simulated trajectories of pre-mRNA and mRNA concentrations over time for the transcription-splicing-degradation model.
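The same comparison can be reproduced without any library support. The sketch below is not the solver used by `solve_transcription_splicing_model`; it swaps in a hand-written fixed-step fourth-order Runge-Kutta integrator and compares it with the \(\gamma^{\ast} = 1\) analytical solution, using the same initial state `[0.1, 0.1]` and measuring time as elapsed dimensionless time since the initial state. The step size and function names are assumptions made for illustration.

```python
import math


def rhs(u, s, gamma):
    """Dimensionless system: du*/dt* = 1 - u*, ds*/dt* = u* - gamma* s*."""
    return 1.0 - u, u - gamma * s


def rk4_step(u, s, gamma, h):
    """Advance (u, s) by one classical Runge-Kutta step of size h."""
    k1u, k1s = rhs(u, s, gamma)
    k2u, k2s = rhs(u + 0.5 * h * k1u, s + 0.5 * h * k1s, gamma)
    k3u, k3s = rhs(u + 0.5 * h * k2u, s + 0.5 * h * k2s, gamma)
    k4u, k4s = rhs(u + h * k3u, s + h * k3s, gamma)
    u_next = u + h * (k1u + 2 * k2u + 2 * k3u + k4u) / 6.0
    s_next = s + h * (k1s + 2 * k2s + 2 * k3s + k4s) / 6.0
    return u_next, s_next


def analytical_gamma_one(tau, u0, s0):
    """Closed-form (u*, s*) for gamma* = 1 after elapsed time tau."""
    decay = math.exp(-tau)
    u = 1.0 + (u0 - 1.0) * decay
    s = 1.0 + (s0 - 1.0) * decay + (u0 - 1.0) * tau * decay
    return u, s


u0, s0, gamma, h = 0.1, 0.1, 1.0, 0.1
u, s = u0, s0
trajectory = [(u, s)]
max_error = 0.0
for step in range(1, 100):
    u, s = rk4_step(u, s, gamma, h)
    ua, sa = analytical_gamma_one(step * h, u0, s0)
    max_error = max(max_error, abs(u - ua), abs(s - sa))
    trajectory.append((u, s))

print(trajectory[1])  # state after one step of elapsed time 0.1
print(max_error)      # largest deviation from the analytical solution
```

The state after the first step is close to the second row of the logged trajectory, and the maximum deviation gives a rough sense of the discretization error for this step size.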

The same simulated trajectories from Figure 1 are shown in the phase space given by \(u^{\ast} \otimes s^{\ast}\) in Figure 3.

Code
from pyrovelocity.plots import plot_deterministic_simulation_phase_portrait

plot_deterministic_simulation_phase_portrait(
    solution=solution,
    title_prefix="TSD Model",
    colormap_name=colormap_name,
)
Figure 3: Phase portrait of the transcription-splicing-degradation model.
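The structure of the phase portrait can be summarized by the fixed point and the local linearization. For the dimensionless system the fixed point is \((u^{\ast}, s^{\ast}) = (1, 1/\gamma^{\ast})\) and the Jacobian is the triangular matrix with rows \((-1, 0)\) and \((1, -\gamma^{\ast})\), so its eigenvalues are \(-1\) and \(-\gamma^{\ast}\) and every trajectory relaxes to the fixed point. A minimal sketch, with an illustrative value of \(\gamma^{\ast}\) and hypothetical helper names:

```python
def velocity(u, s, gamma_star):
    """Vector field (du*/dt*, ds*/dt*) at the point (u*, s*)."""
    return 1.0 - u, u - gamma_star * s


def fixed_point(gamma_star):
    """Point where both components of the vector field vanish."""
    return 1.0, 1.0 / gamma_star


gamma_star = 0.5  # illustrative value (assumption)
u_fp, s_fp = fixed_point(gamma_star)

# Eigenvalues of a triangular Jacobian are its diagonal entries.
eigenvalues = (-1.0, -gamma_star)

# Sample the vector field on a coarse grid of the phase plane.
grid = [0.0, 1.0, 2.0, 3.0]
field = {(u, s): velocity(u, s, gamma_star) for u in grid for s in grid}

print(velocity(u_fp, s_fp, gamma_star))      # zero velocity at the fixed point
print(field[(0.0, 0.0)], field[(3.0, 3.0)])  # arrows point toward the fixed point
print(all(ev < 0 for ev in eigenvalues))     # both modes decay
```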

4 Inference

Now that we have illustrated how to simulate a prototypical system, we can proceed to define a probabilistic model and perform inference on its latent variables given data like \(\eqref{eq-dataset}\). In general we should consider multiple models with different assumptions and levels of complexity, but for now we will focus on one of the simpler models to illustrate the process. In particular, in most cases, when performing inference on the latent variables of a dynamical system, we assume that the observation time points are fixed or otherwise observed. We will begin with this assumption and then proceed to discuss how to relax it.

In the case where we can indeed obtain analytical solutions to the system, we can consider a model with the structure proposed in (\(\ref{eq-init-conds-priors-gene}\)-\(\ref{eq-s-obs-cell}\)) where, given a number of genes \(G\) and cells \(N\) that are measured at \(K_j\) distinct times, for each \((i,j,k) \in \{1, \ldots, G\} \otimes \{1, \ldots, N\} \otimes \{1, \ldots, K_j\}\) we have

\[\begin{align} u^{\ast}_{0i}, s^{\ast}_{0i} &\sim \text{LogNormal}(\mu_{0}, \sigma_{0}^2), \label{eq-init-conds-priors-gene} \\ \gamma^{\ast}_i &\sim \text{LogNormal}(\mu_{\gamma}, \sigma_{\gamma}^2), \label{eq-gamma-prior-gene} \\ \sigma_{ui}, \sigma_{si} &\sim \text{HalfNormal}(\mu_{\sigma}, \sigma_{\sigma}^2), \label{eq-noise-std-priors-gene} \\ {u^{\ast}}^{k}_{ij} &= 1 + (u^{\ast}_{0i} - 1) \cdot e^{-{t^{\ast}}^k_j}, \label{eq-u-star-model-cell} \\ {s^{\ast}}^{k}_{ij} &= \begin{cases} \frac{1}{\gamma^{\ast}_i} + \left( s^{\ast}_{0i} - \xi_i - \frac{1}{\gamma^{\ast}_i} \right) \cdot e^{-\gamma^{\ast}_i {t^{\ast}}^k_j} + \xi_i \cdot e^{-{t^{\ast}}^k_j},& \gamma^{\ast}_i \neq 1 \\ 1 + (s^{\ast}_{0i} - 1) e^{-{t^{\ast}}^k_j} + (u^{\ast}_{0i} - 1) {t^{\ast}}^k_j e^{-{t^{\ast}}^k_j},& \gamma^{\ast}_i = 1 \\ \end{cases}, \label{eq-s-star-model-cell} \\ \hat{u}^{\ast}{}^k_{ij} &\sim \text{Normal}({u^{\ast}}^k_{ij}, \sigma_{ui}^2), \label{eq-u-obs-cell} \\ \hat{s}^{\ast}{}^k_{ij} &\sim \text{Normal}({s^{\ast}}^k_{ij}, \sigma_{si}^2). \label{eq-s-obs-cell} \end{align}\]

The model description in (\(\ref{eq-init-conds-priors-gene}\)-\(\ref{eq-s-obs-cell}\)) can be associated to the plate diagram in Figure 4.

Code
import daft
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 16
plt.rcParams["text.usetex"] = True

pgm = daft.PGM(line_width=1.2)

# hyperparameters
pgm.add_node("mu_init", r"$\mu_{0}$", 0.5, 6, fixed=True)
pgm.add_node("sigma_init", r"$\sigma_{0}^2$", 1.5, 6, fixed=True)
pgm.add_node("mu_gamma", r"$\mu_{\gamma}$", 2.5, 6, fixed=True)
pgm.add_node("sigma_gamma", r"$\sigma_{\gamma}^2$", 3.5, 6, fixed=True)
pgm.add_node("mu_sigma", r"$\mu_{\sigma}$", 4.5, 6, fixed=True)
pgm.add_node("sigma_sigma", r"$\sigma_{\sigma}^2$", 5.5, 6, fixed=True)

# latent variables for gene-specific parameters
pgm.add_node("u_star_0i", r"$u^{\ast}_{0i}$", 1, 5)
pgm.add_node("s_star_0i", r"$s^{\ast}_{0i}$", 2, 5)
pgm.add_node("gamma_star_i", r"$\gamma^{\ast}_i$", 3, 5)
pgm.add_node("sigma_ui", r"$\sigma_{ui}$", 4, 5)
pgm.add_node("sigma_si", r"$\sigma_{si}$", 5, 5)

# latent variables for cell-specific outcomes
pgm.add_node(
    "u_star_ij",
    r"${u^{\ast}}^k_{ij}$",
    2,
    3.8,
    scale=1.0,
    shape="rectangle",
)
pgm.add_node(
    "s_star_ij",
    r"${s^{\ast}}^k_{ij}$",
    4,
    3.8,
    scale=1.0,
    shape="rectangle",
)

# observed data
pgm.add_node(
    "t_star_j",
    r"${t^{\ast}}^k_j$",
    5.9,
    3.1,
    observed=True,
    shape="rectangle",
)
pgm.add_node(
    "u_obs_ij",
    r"$\hat{u}^{\ast}{}^{k}_{ij}$",
    2,
    2.4,
    scale=1.0,
    observed=True,
)
pgm.add_node(
    "s_obs_ij",
    r"$\hat{s}^{\ast}{}^{k}_{ij}$",
    4,
    2.4,
    scale=1.0,
    observed=True,
)

# edges
edge_params = {"head_length": 0.3, "head_width": 0.25, "lw": 0.7}
pgm.add_edge("mu_init", "u_star_0i", plot_params=edge_params)
pgm.add_edge("sigma_init", "u_star_0i", plot_params=edge_params)
pgm.add_edge("mu_init", "s_star_0i", plot_params=edge_params)
pgm.add_edge("sigma_init", "s_star_0i", plot_params=edge_params)
pgm.add_edge("mu_gamma", "gamma_star_i", plot_params=edge_params)
pgm.add_edge("sigma_gamma", "gamma_star_i", plot_params=edge_params)
pgm.add_edge("mu_sigma", "sigma_ui", plot_params=edge_params)
pgm.add_edge("sigma_sigma", "sigma_ui", plot_params=edge_params)
pgm.add_edge("mu_sigma", "sigma_si", plot_params=edge_params)
pgm.add_edge("sigma_sigma", "sigma_si", plot_params=edge_params)

pgm.add_edge("u_star_0i", "u_star_ij", plot_params=edge_params)
pgm.add_edge("s_star_0i", "s_star_ij", plot_params=edge_params)
pgm.add_edge("u_star_0i", "s_star_ij", plot_params=edge_params)
pgm.add_edge("gamma_star_i", "s_star_ij", plot_params=edge_params)

pgm.add_edge("u_star_ij", "u_obs_ij", plot_params=edge_params)
pgm.add_edge("s_star_ij", "s_obs_ij", plot_params=edge_params)
pgm.add_edge("sigma_ui", "u_obs_ij", plot_params=edge_params)
pgm.add_edge("sigma_si", "s_obs_ij", plot_params=edge_params)

pgm.add_edge("t_star_j", "u_star_ij", plot_params=edge_params)
pgm.add_edge("t_star_j", "s_star_ij", plot_params=edge_params)

# plates
pgm.add_plate(
    [0.4, 1.0, 5, 4.5],
    label=r"$i \in \{1, \ldots, G\}$",
    shift=-0.1,
    fontsize=12,
)
pgm.add_plate(
    [0.8, 1.4, 5.9, 3.2],
    label=r"$j \in \{1, \ldots, N\}$",
    shift=-0.1,
    fontsize=12,
)
pgm.add_plate(
    [1.2, 1.8, 5.2, 2.5],
    label=r"$k \in \{1, \ldots, K_j\}$",
    shift=-0.1,
    fontsize=12,
)

pgm.render()
Figure 4: Graphical model for the transcription-splicing-degradation model.
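To make the generative reading of the graphical model concrete, the following sketch draws one set of gene-level latent variables from the priors and then generates noisy observations at a set of observed times using the analytical solution. It uses only the Python standard library rather than a probabilistic programming framework. The hyperparameter values are arbitrary assumptions, and the half-normal prior is treated here as having a scale parameter only, which is an assumption about the intended parameterization.

```python
import math
import random


def s_star(t, u0, s0, gamma):
    """Closed-form s*(t*), switching to the gamma* = 1 branch when needed."""
    if math.isclose(gamma, 1.0, rel_tol=0.0, abs_tol=1e-9):
        return 1.0 + (s0 - 1.0) * math.exp(-t) + (u0 - 1.0) * t * math.exp(-t)
    xi = (u0 - 1.0) / (gamma - 1.0)
    return (
        1.0 / gamma
        + (s0 - xi - 1.0 / gamma) * math.exp(-gamma * t)
        + xi * math.exp(-t)
    )


def sample_gene(rng, times, mu_0=-1.0, sigma_0=0.5, mu_gamma=0.0,
                sigma_gamma=0.5, noise_scale=0.05):
    """Draw one gene's latent variables and noisy observations at `times`."""
    u0 = rng.lognormvariate(mu_0, sigma_0)
    s0 = rng.lognormvariate(mu_0, sigma_0)
    gamma = rng.lognormvariate(mu_gamma, sigma_gamma)
    # Assumption: half-normal noise scales, sampled as |Normal(0, noise_scale)|.
    sigma_u = abs(rng.gauss(0.0, noise_scale))
    sigma_s = abs(rng.gauss(0.0, noise_scale))
    u_obs = [rng.gauss(1.0 + (u0 - 1.0) * math.exp(-t), sigma_u) for t in times]
    s_obs = [rng.gauss(s_star(t, u0, s0, gamma), sigma_s) for t in times]
    return {"u0": u0, "s0": s0, "gamma": gamma, "u_obs": u_obs, "s_obs": s_obs}


rng = random.Random(0)
times = [0.5 * k for k in range(1, 21)]  # observed t* = 0.5, 1.0, ..., 10.0
genes = [sample_gene(rng, times) for _ in range(3)]
for gene in genes:
    print(round(gene["gamma"], 3), round(gene["u_obs"][-1], 3))
```

Sampling forward like this is a common first check that the priors produce trajectories on a plausible scale before any inference is attempted.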

Of course, in general, we will not have access to the analytical solutions, so we will need to simulate the system and then perform inference on the latent variables as suggested in (\(\ref{eq-init-conds-priors-num}\)-\(\ref{eq-s-obs-num}\)).

\[\begin{align} u^{\ast}_0, s^{\ast}_0 &\sim \text{LogNormal}(\mu_{0}, \sigma_{0}^2) \label{eq-init-conds-priors-num}, \\ \gamma^{\ast} &\sim \text{LogNormal}(\mu_{\gamma}, \sigma_{\gamma}^2) \label{eq-gamma-prior-num}, \\ \sigma_u, \sigma_s &\sim \text{HalfNormal}(\mu_{\sigma}, \sigma_{\sigma}^2) \label{eq-noise-std-priors-num}, \\ (u^{\ast}, s^{\ast}) &= \text{NumericalSolver}\left(\frac{du^{\ast}}{dt^{\ast}}, \frac{ds^{\ast}}{dt^{\ast}}, u^{\ast}_0, s^{\ast}_0, \gamma^{\ast}, t^{\ast}\right) \label{eq-numerical-solution}, \\ \hat{u}^{\ast} &\sim \text{Normal}(u^{\ast}, \sigma_u^2) \label{eq-u-obs-num}, \\ \hat{s}^{\ast} &\sim \text{Normal}(s^{\ast}, \sigma_s^2) \label{eq-s-obs-num}. \end{align}\]

Here \(\frac{du^{\ast}}{dt^{\ast}}\) and \(\frac{ds^{\ast}}{dt^{\ast}}\) are given by (\(\ref{eq-dustardtstar}\)-\(\ref{eq-dsstardtstar}\)), and we suppress the indices from (\(\ref{eq-init-conds-priors-gene}\)-\(\ref{eq-s-obs-cell}\)) for brevity.
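The role of the numerical solver inside the likelihood can be illustrated with a deliberately simplified stand-in for full Bayesian inference: a grid search for the value of \(\gamma^{\ast}\) that maximizes the Normal log-likelihood of synthetic observations, with the initial conditions and noise scale held fixed at their true values. Everything in this sketch is an assumption made for illustration, including the fixed-step Runge-Kutta solver, the synthetic data, and the grid; it is not the inference procedure used by Pyro-Velocity.

```python
import math
import random


def rhs(u, s, gamma):
    """Dimensionless system: du*/dt* = 1 - u*, ds*/dt* = u* - gamma* s*."""
    return 1.0 - u, u - gamma * s


def solve_rk4(u0, s0, gamma, n_obs, obs_every, h):
    """Integrate with fixed-step RK4 and keep every `obs_every`-th state."""
    u, s = u0, s0
    out = []
    for step in range(1, n_obs * obs_every + 1):
        k1 = rhs(u, s, gamma)
        k2 = rhs(u + 0.5 * h * k1[0], s + 0.5 * h * k1[1], gamma)
        k3 = rhs(u + 0.5 * h * k2[0], s + 0.5 * h * k2[1], gamma)
        k4 = rhs(u + h * k3[0], s + h * k3[1], gamma)
        u += h * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0]) / 6.0
        s += h * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1]) / 6.0
        if step % obs_every == 0:
            out.append((u, s))
    return out


def log_likelihood(obs, pred, sigma):
    """Sum of independent Normal log densities over all (u*, s*) observations."""
    total = 0.0
    for (u_o, s_o), (u_p, s_p) in zip(obs, pred):
        for o, p in ((u_o, u_p), (s_o, s_p)):
            total += -0.5 * ((o - p) / sigma) ** 2
            total -= math.log(sigma * math.sqrt(2.0 * math.pi))
    return total


rng = random.Random(0)
true_gamma, sigma = 0.5, 0.05
# Observations at t* = 0.2, 0.4, ..., 10.0 (4 solver steps of size 0.05 each).
settings = dict(u0=0.1, s0=0.1, n_obs=50, obs_every=4, h=0.05)
clean = solve_rk4(gamma=true_gamma, **settings)
observed = [(u + rng.gauss(0, sigma), s + rng.gauss(0, sigma)) for u, s in clean]

gamma_grid = [0.2 + 0.05 * k for k in range(37)]  # 0.2, 0.25, ..., 2.0
scores = [
    log_likelihood(observed, solve_rk4(gamma=g, **settings), sigma)
    for g in gamma_grid
]
gamma_hat = gamma_grid[scores.index(max(scores))]
print(gamma_hat)
```

Replacing the grid search with priors on all latent variables and a sampler or variational method recovers the structure of the model above; the point of the sketch is only that each likelihood evaluation requires one numerical solve.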

If, for example, we only observe one time point for each cell, then Figure 4 and its associated description in (\(\ref{eq-init-conds-priors-gene}\)-\(\ref{eq-s-obs-cell}\)) reduce to the graphical model in Figure 5.

Code
import daft
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "serif"
plt.rcParams["font.size"] = 16
plt.rcParams["text.usetex"] = True

pgm = daft.PGM(line_width=1.2)

# hyperparameters
pgm.add_node("mu_init", r"$\mu_{0}$", 0.5, 6, fixed=True)
pgm.add_node("sigma_init", r"$\sigma_{0}^2$", 1.5, 6, fixed=True)
pgm.add_node("mu_gamma", r"$\mu_{\gamma}$", 2.5, 6, fixed=True)
pgm.add_node("sigma_gamma", r"$\sigma_{\gamma}^2$", 3.5, 6, fixed=True)
pgm.add_node("mu_sigma", r"$\mu_{\sigma}$", 4.5, 6, fixed=True)
pgm.add_node("sigma_sigma", r"$\sigma_{\sigma}^2$", 5.5, 6, fixed=True)

# latent variables for gene-specific parameters
pgm.add_node("u_star_0i", r"$u^{\ast}_{0i}$", 1, 5)
pgm.add_node("s_star_0i", r"$s^{\ast}_{0i}$", 2, 5)
pgm.add_node("gamma_star_i", r"$\gamma^{\ast}_i$", 3, 5)
pgm.add_node("sigma_ui", r"$\sigma_{ui}$", 4, 5)
pgm.add_node("sigma_si", r"$\sigma_{si}$", 5, 5)

# latent variables for cell-specific outcomes
pgm.add_node(
    "u_star_ij",
    r"$u^{\ast}_{ij}$",
    2,
    4,
    scale=1.0,
    shape="rectangle",
)
pgm.add_node(
    "s_star_ij",
    r"$s^{\ast}_{ij}$",
    4,
    4,
    scale=1.0,
    shape="rectangle",
)

# observed data
pgm.add_node(
    "t_star_j",
    r"$t^{\ast}_j$",
    6.0,
    3.25,
    observed=True,
    shape="rectangle",
)
pgm.add_node(
    "u_obs_ij",
    r"$\hat{u}^{\ast}_{ij}$",
    2,
    2.5,
    scale=1.0,
    observed=True,
)
pgm.add_node(
    "s_obs_ij",
    r"$\hat{s}^{\ast}_{ij}$",
    4,
    2.5,
    scale=1.0,
    observed=True,
)

# edges
edge_params = {"head_length": 0.3, "head_width": 0.25, "lw": 0.7}
pgm.add_edge("mu_init", "u_star_0i", plot_params=edge_params)
pgm.add_edge("sigma_init", "u_star_0i", plot_params=edge_params)
pgm.add_edge("mu_init", "s_star_0i", plot_params=edge_params)
pgm.add_edge("sigma_init", "s_star_0i", plot_params=edge_params)
pgm.add_edge("mu_gamma", "gamma_star_i", plot_params=edge_params)
pgm.add_edge("sigma_gamma", "gamma_star_i", plot_params=edge_params)
pgm.add_edge("mu_sigma", "sigma_ui", plot_params=edge_params)
pgm.add_edge("sigma_sigma", "sigma_ui", plot_params=edge_params)
pgm.add_edge("mu_sigma", "sigma_si", plot_params=edge_params)
pgm.add_edge("sigma_sigma", "sigma_si", plot_params=edge_params)

pgm.add_edge("u_star_0i", "u_star_ij", plot_params=edge_params)
pgm.add_edge("s_star_0i", "s_star_ij", plot_params=edge_params)
pgm.add_edge("u_star_0i", "s_star_ij", plot_params=edge_params)
pgm.add_edge("gamma_star_i", "s_star_ij", plot_params=edge_params)

pgm.add_edge("u_star_ij", "u_obs_ij", plot_params=edge_params)
pgm.add_edge("s_star_ij", "s_obs_ij", plot_params=edge_params)
pgm.add_edge("sigma_ui", "u_obs_ij", plot_params=edge_params)
pgm.add_edge("sigma_si", "s_obs_ij", plot_params=edge_params)

pgm.add_edge("t_star_j", "u_star_ij", plot_params=edge_params)
pgm.add_edge("t_star_j", "s_star_ij", plot_params=edge_params)

# plates
pgm.add_plate(
    [0.5, 1.2, 5, 4.4],
    label=r"$i \in \{1, \ldots, G\}$",
    shift=-0.1,
    fontsize=12,
)
pgm.add_plate(
    [1.0, 1.8, 5.5, 2.75],
    label=r"$j \in \{1, \ldots, N\}$",
    shift=-0.1,
    fontsize=12,
)

pgm.render()
Figure 5: Graphical model for the transcription-splicing-degradation model observed at a single time point.
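The generative process depicted in Figure 5 can be sketched as forward sampling for a single gene. This is a rough illustration under stated assumptions, not the Pyro-Velocity implementation: the dimensionless dynamics are assumed to be \(du^{\ast}/dt^{\ast} = 1 - u^{\ast}\), \(ds^{\ast}/dt^{\ast} = u^{\ast} - \gamma^{\ast} s^{\ast}\), the half-normal noise priors are taken to be zero-centred, the observation times are drawn uniformly purely for demonstration, and the function names and hyperparameter values are hypothetical.

```python
import math
import random


def solve_exact(u0, s0, gamma, t):
    """Closed-form solution of the ASSUMED system
    du*/dt* = 1 - u*,  ds*/dt* = u* - gamma* s*."""
    u = 1.0 + (u0 - 1.0) * math.exp(-t)
    if abs(gamma - 1.0) < 1e-12:
        # limiting form when gamma* equals one
        s = 1.0 + ((u0 - 1.0) * t + (s0 - 1.0)) * math.exp(-t)
    else:
        a = (u0 - 1.0) / (gamma - 1.0)
        s = (
            1.0 / gamma
            + a * math.exp(-t)
            + (s0 - 1.0 / gamma - a) * math.exp(-gamma * t)
        )
    return u, s


def sample_single_timepoint_data(num_cells, t_max=4.0, seed=0):
    """Draw one (t*, u_hat, s_hat) triple per cell for a single gene."""
    rng = random.Random(seed)
    # Gene-level latent variables; hyperparameter values are illustrative.
    # random.lognormvariate takes the standard deviation of the underlying
    # normal distribution, not the variance.
    u0 = rng.lognormvariate(math.log(0.1), 0.5)
    s0 = rng.lognormvariate(math.log(0.1), 0.5)
    gamma = rng.lognormvariate(0.0, 0.5)
    # Zero-centred half-normal draws for the observation noise scales.
    sigma_u = abs(rng.gauss(0.0, 0.05))
    sigma_s = abs(rng.gauss(0.0, 0.05))
    cells = []
    for _ in range(num_cells):
        t = rng.uniform(0.0, t_max)  # one observation time per cell
        u, s = solve_exact(u0, s0, gamma, t)
        cells.append((t, rng.gauss(u, sigma_u), rng.gauss(s, sigma_s)))
    return cells


cells = sample_single_timepoint_data(num_cells=20)
print(len(cells), cells[0])
```

Each cell contributes a single noisy \((\hat{u}^{\ast}, \hat{s}^{\ast})\) pair at its own time \(t^{\ast}_j\), while the initial conditions, \(\gamma^{\ast}\), and the noise scales are shared across cells for the gene, mirroring the plate structure of the figure.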

Now we can set up some sample data and perform inference on the latent variables.

Code
import arviz as az
from pyrovelocity.models._deterministic_inference import (
    generate_posterior_inference_data,
    generate_prior_inference_data,
    generate_test_data_for_deterministic_model_inference,
    plot_sample_phase_portraits,
    plot_sample_trajectories,
    plot_sample_trajectories_with_percentiles,
    print_inference_data_structure,
)

(
    times,
    data,
    num_genes,
    num_cells,
    num_timepoints,
    num_modalities,
) = generate_test_data_for_deterministic_model_inference(
    num_genes=1,
    num_cells=20,
    num_timepoints=1,
    num_modalities=2,
    noise_levels=(0.001, 0.001),
)

num_chains = 1
num_samples = 10
num_warmup = 10

idata_prior = generate_prior_inference_data(
    times=times,
    data=data,
    num_chains=num_chains,
    num_samples=num_samples,
    num_genes=num_genes,
    num_cells=num_cells,
    num_timepoints=num_timepoints,
    num_modalities=num_modalities,
)

idata_posterior = generate_posterior_inference_data(
    times=times,
    data=data,
    num_chains=num_chains,
    num_samples=num_samples,
    num_genes=num_genes,
    num_cells=num_cells,
    num_timepoints=num_timepoints,
    num_modalities=num_modalities,
    num_warmup=num_warmup,
)

                             mean       std    median      5.0%     95.0%     n_eff     r_hat
               gamma[0]      0.79      0.08      0.77      0.69      0.92      4.19      1.39
initial_conditions[0,0]      0.10      0.01      0.10      0.08      0.11      5.14      1.12
initial_conditions[0,1]      0.10      0.01      0.10      0.09      0.12      5.73      1.26
               sigma[0]      0.08      0.04      0.06      0.04      0.14      2.69      1.99
               sigma[1]      0.18      0.06      0.14      0.12      0.27      2.85      1.95

Number of divergences: 4
[21:26:08] INFO     pyrovelocity.models._deterministic_inference Generated test data tensor shape: (1, 20, 1, 2)   
(a) Test data
           INFO     pyrovelocity.models._deterministic_inference Generated test time array shape: (20, 1)          
(b)
[21:26:09] INFO     pyrovelocity.models._deterministic_inference                                                   
                    Prior Inference Data                                                                           
                                                                                                                   
                    Overview of InferenceData structure:                                                           
                                                                                                                   
                    Group: posterior_predictive                                                                    
                      Variables and their dimensions:                                                              
                      observations: (chain=1, draw=10, genes=1, cells=20, timepoints=1, modalities=2)              
                      times: (chain=1, draw=10, cells=20, timepoints=1)                                            
                                                                                                                   
                                                                                                                   
                    Group: prior                                                                                   
                      Variables and their dimensions:                                                              
                      gamma: (chain=1, draw=10, genes=1)                                                           
                      initial_conditions: (chain=1, draw=10, genes=1, modalities=2)                                
                      observations: (chain=1, draw=10, genes=1, cells=20, timepoints=1, modalities=2)              
                      sigma: (chain=1, draw=10, modalities=2)                                                      
                      times: (chain=1, draw=10, cells=20, timepoints=1)                                            
                                                                                                                   
                                                                                                                   
                    Group: observed_data                                                                           
                      Variables and their dimensions:                                                              
                      times: (cells=20, timepoints=1)                                                              
                      observations: (genes=1, cells=20, timepoints=1, modalities=2)                                
                                                                                                                   
(c)
[21:26:17] INFO     pyrovelocity.models._deterministic_inference                                                   
                    Posterior Inference Data                                                                       
                                                                                                                   
                    Overview of InferenceData structure:                                                           
                                                                                                                   
                    Group: posterior                                                                               
                      Variables and their dimensions:                                                              
                      gamma: (chain=1, draw=10, genes=1)                                                           
                      initial_conditions: (chain=1, draw=10, genes=1, modalities=2)                                
                      sigma: (chain=1, draw=10, modalities=2)                                                      
                                                                                                                   
                                                                                                                   
                    Group: posterior_predictive                                                                    
                      Variables and their dimensions:                                                              
                      observations: (chain=1, draw=10, genes=1, cells=20, timepoints=1, modalities=2)              
                      times: (chain=1, draw=10, cells=20, timepoints=1)                                            
                                                                                                                   
                                                                                                                   
                    Group: log_likelihood                                                                          
                      Variables and their dimensions:                                                              
                      observations: (chain=1, draw=10, genes=1, cells=20, timepoints=1, modalities=2)              
                      times: (chain=1, draw=10, cells=20, timepoints=1)                                            
                                                                                                                   
                                                                                                                   
                    Group: sample_stats                                                                            
                      Variables and their dimensions:                                                              
                      diverging: (chain=1, draw=10)                                                                
                                                                                                                   
                                                                                                                   
                    Group: prior                                                                                   
                      Variables and their dimensions:                                                              
                      gamma: (chain=1, draw=10, genes=1)                                                           
                      initial_conditions: (chain=1, draw=10, genes=1, modalities=2)                                
                      sigma: (chain=1, draw=10, modalities=2)                                                      
                                                                                                                   
                                                                                                                   
                    Group: prior_predictive                                                                        
                      Variables and their dimensions:                                                              
                      observations: (chain=1, draw=10, genes=1, cells=20, timepoints=1, modalities=2)              
                      times: (chain=1, draw=10, cells=20, timepoints=1)                                            
                                                                                                                   
                                                                                                                   
                    Group: observed_data                                                                           
                      Variables and their dimensions:                                                              
                      times: (cells=20, timepoints=1)                                                              
                      observations: (genes=1, cells=20, timepoints=1, modalities=2)                                
                                                                                                                   
(d)
Figure 6

We can plot the prior distributions for the latent variables.

Code
light_gray = "#bcbcbc"
variables = ["initial_conditions", "gamma", "sigma"]
for var in variables:
    az.plot_posterior(
        idata_prior,
        var_names=[var],
        group="prior",
        kind="hist",
        color=light_gray,
        round_to=2,
    )
Figure 7: Prior distributions for the transcription-splicing-degradation model. Panels (a)-(c) correspond to `initial_conditions`, `gamma`, and `sigma`.

We can likewise plot the posterior distributions for the latent variables.

Code
light_gray = "#bcbcbc"
variables = ["initial_conditions", "gamma", "sigma"]
for var in variables:
    az.plot_posterior(
        idata_posterior,
        var_names=[var],
        group="posterior",
        kind="hist",
        color=light_gray,
        round_to=2,
    )
Figure 8: Posterior distributions for the transcription-splicing-degradation model. Panels (a)-(c) correspond to `initial_conditions`, `gamma`, and `sigma`.

We can plot posterior predictive trajectories.

Code
figs = plot_sample_trajectories(
    idata=idata_posterior,
)
for fig in figs:
    fig.show()
Figure 9: Posterior predictive trajectories for the transcription-splicing-degradation model.

We can also plot these trajectories in phase portrait space.

Code
figs = plot_sample_phase_portraits(
    idata=idata_posterior,
    colormap_name="RdBu",
)
for fig in figs:
    fig.show()
Figure 10: Posterior predictive phase portraits for the transcription-splicing-degradation model.

We can also plot estimates of the distributions of pre-mRNA and mRNA.

Code
figs = plot_sample_trajectories_with_percentiles(
    idata=idata_posterior,
)
for fig in figs:
    fig.show()
Figure 11: Posterior predictive distribution for the transcription-splicing-degradation model.
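A summary like the one in Figure 11 reduces a set of posterior predictive draws to percentile bands at each time point. The following is a minimal sketch of that reduction, not the code behind `plot_sample_trajectories_with_percentiles`; the function name `percentile_band`, the 5th and 95th percentile levels, and the toy draws are all assumptions chosen for illustration.

```python
import statistics


def percentile_band(draws, lower=5, upper=95):
    """Summarize trajectories by percentiles at each time point.

    draws: list of trajectories, each a list of values on a shared time grid.
    Returns three lists (lower, median, upper) with one value per time point.
    """
    lo, mid, hi = [], [], []
    for values_at_t in zip(*draws):
        # 99 cut points; cuts[k - 1] is the k-th percentile
        cuts = statistics.quantiles(values_at_t, n=100, method="inclusive")
        lo.append(cuts[lower - 1])
        mid.append(statistics.median(values_at_t))
        hi.append(cuts[upper - 1])
    return lo, mid, hi


# Toy example: five draws observed at two time points.
toy_draws = [
    [0.0, 10.0],
    [1.0, 11.0],
    [2.0, 12.0],
    [3.0, 13.0],
    [4.0, 14.0],
]
lo, mid, hi = percentile_band(toy_draws)
print(lo, mid, hi)
```

With only 10 draws, as in the run above, such bands are coarse; more samples would be needed before reading much into their width.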

We can also plot the prior predictive distribution.

Code
figs = plot_sample_trajectories_with_percentiles(
    # pass the prior inference data so that prior predictive samples are plotted
    idata=idata_prior,
)
for fig in figs:
    fig.show()
Figure 12: Prior predictive distribution for the transcription-splicing-degradation model.

References

1.
2. Alon, U.: An introduction to systems biology: Design principles of biological circuits. CRC Press LLC (2019)
3. Zeisel, A., Köstler, W.J., Molotski, N., Tsai, J.M., Krauthgamer, R., Jacob‐Hirsch, J., Rechavi, G., Soen, Y., Jung, S., Yarden, Y., Domany, E.: Coupled pre‐mRNA and mRNA dynamics unveil operational strategies underlying transcriptional responses to stimuli. Mol. Syst. Biol. 7, 529 (2011). https://doi.org/10.1038/msb.2011.62
4. La Manno, G., Soldatov, R., Zeisel, A., RNA velocity authors, Linnarsson, S., Kharchenko, P.V.: RNA velocity of single cells. Nature. (2018). https://doi.org/10.1038/s41586-018-0414-6
5. Bergen, V., Lange, M., Peidli, S., Wolf, F.A., Theis, F.J.: Generalizing RNA velocity to transient cell states through dynamical modeling. Nat. Biotechnol. 38, 1408–1414 (2020). https://doi.org/10.1038/s41587-020-0591-3
6. Cantwell, B.J.: Introduction to symmetry analysis. Cambridge University Press (2002)
7. Phillips, R., Kondev, J., Theriot, J., Garcia, H.: Physical biology of the cell. Garland Science (2012)
8. Ham, L., Schnoerr, D., Brackston, R.D., Stumpf, M.P.H.: Exactly solvable models of stochastic gene expression. J. Chem. Phys. 152, 144106 (2020). https://doi.org/10.1063/1.5143540
9. Cao, Z., Grima, R.: Analytical distributions for detailed models of stochastic gene expression in eukaryotic cells. Proc. Natl. Acad. Sci. U. S. A. 117, 4682–4692 (2020). https://doi.org/10.1073/pnas.1910888117
10. Bohrer, C.H., Larson, D.R.: The stochastic genome and its role in gene expression. Cold Spring Harb. Perspect. Biol. 13, (2021). https://doi.org/10.1101/cshperspect.a040386
11. Pegoraro, F.: The dilation group and dimensionless quantities in classical and relativistic hydrodynamics. Meccanica. 8, 216–222 (1973). https://doi.org/10.1007/BF02342407