"""Main: run a sensitivity analysis of the Lotka-Volterra equations."""
import argparse
import time
import numpy as np
import numpy.typing
from loguru import logger
from lv_sens import data, ode, sensitivity
from lv_sens.types import NDArray_f64
def _num_steps_type(value: str) -> int:
    """Parse the ``--num-steps`` option, requiring at least 2 time steps.

    The time step is computed as ``T / (N - 1)``, so ``N < 2`` would either
    divide by zero or yield a negative time step.

    Args:
        value: raw command line string

    Returns:
        the number of time steps as an int

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not an integer >= 2
    """
    try:
        num_steps = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if num_steps < 2:
        raise argparse.ArgumentTypeError(f"must be at least 2, got {num_steps}")
    return num_steps


def get_parser() -> argparse.ArgumentParser:
    """Return arguments parser.

    The parser takes one positional ``command`` (``forward``, ``inverse`` or
    ``check-gradient``) plus optional simulation, data-generation and
    optimization settings. ``--num-steps`` is validated to be at least 2.

    Returns:
        the configured ``argparse.ArgumentParser``
    """
    parser = argparse.ArgumentParser(
        description="Run a sensitivity analysis of the Lotka-Volterra equations.\n",
        # Raw formatting keeps the explicit line breaks in the `command` help.
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "command",
        choices=["forward", "inverse", "check-gradient"],
        help=(
            "Select one of the following actions:\n"
            "- forward: perform a forward simulation\n"
            "- inverse: perform an inverse simulation (parameters optimization)\n"
            "- check-gradient: compare the cost function gradient computation with "
            "different methods."
        ),
    )
    parser.add_argument(
        "-n",
        "--num-steps",
        metavar="N",
        type=_num_steps_type,
        default=100000,
        help="Number of time steps (at least 2).",
    )
    parser.add_argument(
        "-T",
        "--time-final",
        metavar="T",
        type=float,
        default=10.0,
        help="Final time of the simulation.",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=123,
        help="Random seed for measurement noise generation.",
    )
    parser.add_argument(
        "--noise-std",
        type=float,
        default=0.01,
        help="Measurement noise standard deviation for data generation.",
    )
    parser.add_argument(
        "-g",
        "--gradient",
        choices=["adjoint", "fd", "none"],
        default="adjoint",
        help="Gradient computation method for the optimization problem.",
    )
    parser.add_argument(
        "--theta-guess",
        metavar=("alpha", "beta", "delta"),
        type=float,
        nargs=3,
        default=[1.2, 0.8, 2.4],
        help="Initial guess of parameters theta for the optimization.",
    )
    parser.add_argument(
        "--theta-ref",
        metavar=("alpha", "beta", "delta"),
        type=float,
        nargs=3,
        default=[1.5, 1.0, 3.0],
        help=(
            "Parameter reference values for forward problem and data generation "
            "(ground truth) and gradient comparison (`check-gradient`)."
        ),
    )
    parser.add_argument(
        "--initial-condition",
        metavar=("x0", "y0"),
        type=float,
        nargs=2,
        default=[1.0, 1.0],
        help="Initial condition of the ODE solution.",
    )
    parser.add_argument(
        "--plot-data", action="store_true", help="Plot measurement data"
    )
    return parser
def compare_gradient_fd(
    theta: NDArray_f64,
    x_data: NDArray_f64,
    dt: float,
    N: int,
    x0: NDArray_f64,
    eps: float = 1e-8,
) -> None:
    """Compare the adjoint-based gradient with a finite difference approximation.

    Both gradients are evaluated at ``theta``. The function logs the time each
    computation took, both gradients, and their componentwise absolute and
    relative errors. The relative error is taken with respect to the FD
    gradient.

    Args:
        theta: parameters at which to evaluate gradient
        x_data: data array
        dt: time step
        N: number of time steps
        x0: initial condition of the ODE
        eps: step size forwarded to ``sensitivity.grad_fd``
    """
    logger.info("Comparing adjoint gradient with finite difference gradient")

    tic = time.perf_counter()
    grad_adjoint = sensitivity.grad_adj(theta, x_data, dt, N, x0)
    logger.info(f"Computed initial adjoint gradient in {time.perf_counter() - tic}s")

    tic = time.perf_counter()
    grad_fd = sensitivity.grad_fd(theta, x_data, dt, N, x0, eps)
    logger.info(f"Computed initial FD gradient in {time.perf_counter() - tic}s")

    abs_error = np.abs(grad_adjoint - grad_fd)
    # A vanishing FD gradient component yields inf/nan for the relative error;
    # report that value instead of emitting a RuntimeWarning or raising.
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_error = abs_error / np.abs(grad_fd)

    logger.info(f"Adjoint gradient: \t {grad_adjoint}")
    logger.info(f"FD gradient: \t\t {grad_fd}")
    logger.info("---")
    logger.info(f"absolute error (FD vs Adj): {abs_error}")
    logger.info(f"relative error (FD vs Adj): {rel_error}")
def generate_data(
    theta: NDArray_f64,
    dt: float,
    N: int,
    x0: NDArray_f64,
    noise_std: float,
    plot: bool = False,
) -> NDArray_f64:
    """Generate measurement data from a ground truth forward solution.

    The reference trajectory comes from ``ode.forward``. ``data.noisy_data``
    then perturbs it with noise of the given standard deviation.

    Args:
        theta: parameter for reference solution
        dt: time step
        N: number of time steps
        x0: initial condition of the ODE
        noise_std: noise standard deviation
        plot: toggle to plot solution and data points (requires matplotlib;
            skipped with a warning if it is not installed)

    Returns:
        data array
    """
    logger.info("Generating data")
    tic = time.perf_counter()
    x_ref = ode.forward(theta, dt, N, x0)
    logger.info(f"Forward solve - time elapsed: {time.perf_counter() - tic}s")
    x_data = data.noisy_data(x_ref, noise_std)

    if plot:
        # Only the import is guarded: plotting errors must not be reported
        # as a missing matplotlib.
        try:
            import matplotlib.pyplot as plt
        except ImportError:
            logger.warning("matplotlib not found! skipping plots...")
        else:
            # Indexing below assumes arrays laid out as (component, time step)
            # with N time steps, matching the N-point time grid.
            times = np.linspace(0, dt * (N - 1), num=N)
            plt.plot(times, x_ref[0, :])
            plt.plot(times, x_ref[1, :])
            plt.plot(times, x_data[0, :], ".")
            plt.plot(times, x_data[1, :], ".")
            plt.show()
    return x_data
def main() -> None:
    """Main function.

    Parse command line arguments and run the chosen simulation:

    - ``forward``: solve the ODE with ``--theta-ref`` and log the run time;
    - ``inverse``: generate noisy data from ``--theta-ref`` and optimize the
      parameters starting from ``--theta-guess``;
    - ``check-gradient``: generate noisy data and compare the adjoint and
      finite difference gradients at ``--theta-ref``.

    ``--plot-data`` plots the generated data for both data-generating commands.
    """
    parser = get_parser()
    args = parser.parse_args()
    N = args.num_steps
    if N < 2:
        # dt = T / (N - 1) below needs at least two time steps.
        parser.error(f"argument -n/--num-steps: must be at least 2, got {N}")

    np.random.seed(args.seed)
    T_end = args.time_final
    dt = T_end / (N - 1)
    theta = np.array(args.theta_ref)
    x0 = np.array(args.initial_condition)

    if args.command == "forward":
        tic = time.perf_counter()
        logger.info("Forward-solving the ODE")
        ode.forward(theta, dt, N, x0)
        logger.info(f"time elapsed: {time.perf_counter() - tic}s")
    elif args.command == "inverse":
        x_data = generate_data(theta, dt, N, x0, args.noise_std, plot=args.plot_data)
        logger.info(f"Optimizing with cost function gradient type '{args.gradient}'")
        tic = time.perf_counter()
        res = sensitivity.optimize(args.theta_guess, x_data, dt, N, x0, args.gradient)
        logger.info(res)
        logger.info(f"Time elapsed: {time.perf_counter() - tic} s")
    elif args.command == "check-gradient":
        x_data = generate_data(theta, dt, N, x0, args.noise_std, plot=args.plot_data)
        compare_gradient_fd(theta, x_data, dt, N, x0)
    else:
        # Unreachable via the CLI: argparse restricts `command` to the choices.
        raise ValueError(f"Unknown command: {args.command}")
# Run the command line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()