Examples

Orbit Fitting

This example demonstrates how to use sgp4jax’s differentiable SGP4 propagator to fit orbital elements to noisy position observations.

Because sgp4jax is built on JAX, we get automatic differentiation for free — enabling gradient-based optimization and analytic uncertainty estimation.

Workflow:

  1. Generate synthetic observations from a known TLE with Gaussian noise

  2. Define a forward model mapping orbital parameters → predicted positions

  3. Fit 7 orbital parameters using JAX’s built-in BFGS optimizer

  4. Estimate parameter uncertainties via Fisher information

  1#!/usr/bin/env python
  2"""Orbit Fitting with sgp4jax
  3==============================
  4
  5This example demonstrates how to use sgp4jax's **differentiable SGP4
  6propagator** to fit orbital elements to noisy position observations.
  7
  8Because sgp4jax is built on JAX, we get automatic differentiation for
  9free — enabling gradient-based optimization and analytic uncertainty
 10estimation.
 11
 12Workflow:
 13
 141. Generate synthetic observations from a known TLE with Gaussian noise
 152. Define a forward model mapping orbital parameters → predicted positions
 163. Fit 7 orbital parameters using JAX's built-in BFGS optimizer
 174. Estimate parameter uncertainties via Fisher information
 18"""
 19
 20# %%
 21# Imports
 22# -------
 23
 24import jax
 25import jax.numpy as jnp
 26from jax.scipy.optimize import minimize as jax_minimize
 27
 28import sgp4jax
 29from sgp4jax import WGS72, tle_to_satrec, propagate
 30from sgp4jax._sgp4init import sgp4init
 31
 32
 33# %%
 34# 1. Generate Synthetic Observations
 35# -----------------------------------
 36#
 37# We start from a known ISS TLE as ground truth, propagate it to 50 time
 38# points over one day (1440 minutes), and add Gaussian noise with σ = 1 km
 39# to the positions.  This noise level is realistic for ground-based radar
 40# tracking.
 41
 42line1 = "1 25544U 98067A   24045.51782528  .00016717  00000-0  10270-3 0  9006"
 43line2 = "2 25544  51.6400  10.2827 0003856 197.0300 163.0590 15.49560044439368"
 44
 45sat_true = tle_to_satrec(line1, line2)
 46
 47times = jnp.linspace(0.0, 1440.0, 50)
 48
 49r_true, v_true, errs = jax.vmap(propagate, (None, 0))(sat_true, times)
 50print(f"Propagated {len(times)} time steps, positions shape: {r_true.shape}")
 51
 52sigma = 1.0  # km
 53key = jax.random.PRNGKey(42)
 54noise = sigma * jax.random.normal(key, shape=r_true.shape)
 55r_obs = r_true + noise
 56
 57print(f"RMS noise: {jnp.sqrt(jnp.mean(noise**2)):.3f} km")
 58
 59
 60# %%
 61# 2. The Forward Model
 62# ---------------------
 63#
 64# We define a function that takes 7 orbital parameters, builds a ``SatRec``
 65# via ``sgp4init``, and propagates to all observation times.
 66#
 67# ========  ============  ================================
 68# Index     Parameter     Description
 69# ========  ============  ================================
 70# 0         ``inclo``     Inclination (rad)
 71# 1         ``nodeo``     Right ascension of ascending node (rad)
 72# 2         ``ecco``      Eccentricity
 73# 3         ``argpo``     Argument of perigee (rad)
 74# 4         ``mo``        Mean anomaly (rad)
 75# 5         ``no_kozai``  Mean motion (rad/min), Kozai
 76# 6         ``bstar``     Drag coefficient (B*)
 77# ========  ============  ================================
 78
 79def predict_positions(params, gravity, epoch, jdsatepoch, jdsatepochF, times):
 80    """Forward model: orbital parameters -> predicted positions."""
 81    inclo = params[0]
 82    nodeo = params[1]
 83    ecco = params[2]
 84    argpo = params[3]
 85    mo = params[4]
 86    no_kozai = params[5]
 87    bstar = params[6]
 88
 89    sat = sgp4init(
 90        gravity, epoch, bstar,
 91        0.0, 0.0,  # ndot, nddot (fixed)
 92        ecco, argpo, inclo, mo, no_kozai, nodeo,
 93        jdsatepoch, jdsatepochF,
 94    )
 95    r, v, err = jax.vmap(propagate, (None, 0))(sat, times)
 96    return r  # (n_times, 3)
 97
 98
 99# %%
100# 3. Loss Function
101# -----------------
102#
103# Weighted sum of squared residuals:
104#
105# .. math::
106#
107#    \mathcal{L}(\boldsymbol{\theta})
108#    = \frac{1}{2\sigma^2}
109#      \sum_{i=1}^{N}
110#      \|\mathbf{r}_\text{pred}(t_i;\boldsymbol{\theta})
111#        - \mathbf{r}_\text{obs}(t_i)\|^2
112
113def loss_fn(params, gravity, epoch, jdsatepoch, jdsatepochF, times, r_obs, sigma):
114    """Weighted sum of squared residuals."""
115    r_pred = predict_positions(params, gravity, epoch, jdsatepoch, jdsatepochF, times)
116    residuals = r_pred - r_obs
117    return 0.5 * jnp.sum(residuals**2) / sigma**2
118
119
# %%
# 4. Initial Guess & Optimization
# --------------------------------
#
# We start from a slightly perturbed version of the true parameters and
# use JAX's built-in BFGS optimizer (``jax.scipy.optimize.minimize``).
#
# Parameter scaling is critical: we normalize by the expected perturbation
# size so that the BFGS initial Hessian approximation (identity) produces
# conservative step sizes across all parameters.

# Single source of truth for the parameter order used by predict_positions.
param_names = ["inclo", "nodeo", "ecco", "argpo", "mo", "no_kozai", "bstar"]

true_params = jnp.array([getattr(sat_true, name) for name in param_names])

gravity = WGS72
# sgp4init takes the epoch as days since 1949 December 31 00:00 UT,
# which is Julian date 2433281.5.
epoch = float(sat_true.jdsatepoch) + float(sat_true.jdsatepochF) - 2433281.5
jdsatepoch = float(sat_true.jdsatepoch)
jdsatepochF = float(sat_true.jdsatepochF)

key2 = jax.random.PRNGKey(123)
# 1-σ size of the initial-guess perturbation for each parameter, in the
# same order as param_names.
perturbation = jnp.array([1e-4, 1e-4, 1e-5, 1e-4, 1e-4, 1e-6, 1e-6])
x0 = true_params + perturbation * jax.random.normal(key2, shape=perturbation.shape)

# Parameter scaling: normalize by perturbation magnitude so that a
# unit step in the scaled space corresponds to a perturbation-sized
# step in physical space.  This keeps the BFGS line search stable.
param_scale = perturbation


def scaled_loss(x_scaled, param_scale, gravity, epoch, jdsatepoch,
                jdsatepochF, times, r_obs, sigma):
    """Loss in normalized parameter space.

    Physical parameters are recovered element-wise as
    ``params = x_scaled * param_scale``, so the optimizer works in units of
    ``param_scale``.  All remaining arguments are forwarded to ``loss_fn``.
    """
    params = x_scaled * param_scale
    return loss_fn(params, gravity, epoch, jdsatepoch, jdsatepochF,
                   times, r_obs, sigma)


x0_scaled = x0 / param_scale

result = jax_minimize(
    scaled_loss, x0_scaled,
    args=(param_scale, gravity, epoch, jdsatepoch, jdsatepochF,
          times, r_obs, sigma),
    method="BFGS",
)

# Map the optimum back from normalized to physical parameter units.
params_fit = result.x * param_scale
print(f"\nOptimization finished: fun_val = {float(result.fun):.4f}, "
      f"nit = {int(result.nit)}")
# BFGS can also stop on an iteration or line-search limit; report the
# outcome instead of silently trusting params_fit.
print(f"converged = {bool(result.success)}, status = {int(result.status)}")
# The loss is chi^2 / 2, so 2 * loss / dof should be close to 1 when the
# model fits the data to within the assumed noise level sigma.
dof = r_obs.size - len(param_names)
print(f"reduced chi^2 = {2.0 * float(result.fun) / dof:.3f}")
print()
header = (f"{'Parameter':12s}  {'Initial':>14s}  "
          f"{'Optimized':>14s}  {'True':>14s}")
print(header)
print("-" * len(header))
for name, xi, fi, ti in zip(param_names, x0, params_fit, true_params):
    print(f"{name:12s}  {float(xi):14.8f}  {float(fi):14.8f}  {float(ti):14.8f}")


# %%
# 5. Fisher Information & Parameter Uncertainties
# -------------------------------------------------
#
# With the Jacobian of the forward model we estimate parameter
# uncertainties via the Fisher information matrix:
#
# .. math::
#
#    \mathbf{F} = \frac{1}{\sigma^2}\,\mathbf{J}^\top\mathbf{J},
#    \qquad
#    \mathrm{Cov}(\boldsymbol{\theta}) \approx \mathbf{F}^{-1}
#
# The 1-σ uncertainties are
# :math:`\sqrt{\mathrm{diag}(\mathbf{F}^{-1})}`.

jacobian_fn = jax.jit(jax.jacobian(predict_positions))
J_full = jacobian_fn(params_fit, gravity, epoch, jdsatepoch, jdsatepochF, times)

# J_full has shape (n_times, 3, n_params); flatten the observation axes so
# that each row corresponds to one scalar position measurement.
n_times = times.shape[0]
n_params = params_fit.shape[0]
J = J_full.reshape(n_times * 3, n_params)

F = J.T @ J / sigma**2

# The parameters span several orders of magnitude (bstar ~1e-4 vs. angles
# ~1 rad), so the columns of J, and hence F, are badly scaled.  Invert the
# column-equilibrated matrix D^-1 F D^-1 (D = diag of the column norms of J)
# and undo the scaling afterwards.  This is algebraically identical to
# inv(F) but loses far less floating-point precision.
col_norm = jnp.linalg.norm(J, axis=0)
col_outer = jnp.outer(col_norm, col_norm)
cov = jnp.linalg.inv(F / col_outer) / col_outer
uncertainties = jnp.sqrt(jnp.diag(cov))

header = f"{'Parameter':12s}  {'Best Fit':>14s}  1-sigma Uncertainty"
print("\n" + header)
print("-" * len(header))
for name, fi, ui in zip(param_names, params_fit, uncertainties):
    print(f"{name:12s}  {float(fi):14.8f}  +/- {float(ui):.2e}")


# %%
# 6. Results
# ----------
#
# Visualize the position residuals and the parameter correlation matrix.

# Plotting is optional.  The whole section stays inside the ``try`` so that
# an ImportError raised while matplotlib sets up its backend (which may
# happen after the import itself) also just skips the plots.
try:
    import matplotlib.pyplot as plt

    r_fit = predict_positions(
        params_fit, gravity, epoch, jdsatepoch, jdsatepochF, times)
    residuals = r_fit - r_obs

    fig, axes = plt.subplots(2, 1, figsize=(10, 8))

    # Plot 1: Position residuals, with the ±1σ noise band for reference
    ax = axes[0]
    for i, label in enumerate(["X", "Y", "Z"]):
        ax.plot(times, residuals[:, i], ".", label=label, markersize=4)
    ax.axhline(sigma, color="gray", ls="--", alpha=0.5, label=f"$\\pm${sigma} km")
    ax.axhline(-sigma, color="gray", ls="--", alpha=0.5)
    ax.set_xlabel("Time since epoch (min)")
    ax.set_ylabel("Residual (km)")
    ax.set_title("Position Residuals (fit - observed)")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 2: Correlation matrix, corr_ij = cov_ij / (std_i * std_j)
    ax = axes[1]
    std = jnp.sqrt(jnp.diag(cov))
    corr = cov / jnp.outer(std, std)
    tick_positions = range(len(param_names))
    im = ax.imshow(corr, cmap="RdBu_r", vmin=-1, vmax=1)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(param_names, rotation=45, ha="right")
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(param_names)
    ax.set_title("Parameter Correlation Matrix")
    fig.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()

except ImportError:
    print("(matplotlib not available — skipping plots)")


# %%
# Summary
# -------
#
# This example demonstrated a complete orbit determination workflow:
#
# 1. **Synthetic data** — propagated a known TLE and added σ = 1 km
#    Gaussian noise
# 2. **Forward model** — ``sgp4init`` + ``propagate`` map orbital
#    parameters to positions
# 3. **BFGS optimization** — JAX automatic differentiation provides
#    exact gradients through the entire SGP4 computation
# 4. **Uncertainty estimation** — the Fisher information matrix from
#    the Jacobian gives 1-σ parameter uncertainties and correlations
#
# The key advantage of sgp4jax is **differentiability**: gradients,
# Jacobians, and Hessians are computed automatically and efficiently,
# enabling gradient-based optimization and rigorous uncertainty
# quantification without finite differences.

7×7 Element-Space Prior Covariance

This example shows how to build a full 7×7 prior covariance matrix over the SGP4 element vector (inclo, nodeo, ecco, argpo, mo, no_kozai, bstar) at a target time one day after the TLE epoch.

The approach combines three functions:

  1. tle_ric_covariance() — empirical 6×6 RIC position/velocity covariance based on TLE age and drag coefficient.

  2. cov_ric_to_elements() — transforms the RIC covariance to 6-element space (inclo, nodeo, ecco, argpo, mo, no_kozai) via the Keplerian Jacobian. This is a square, full-rank transform.

  3. tle_bstar_sigma() — empirical 1-σ for bstar based on TLE age and atmospheric-density variability, appended as an independent 7th block.

bstar is treated as independent of the Keplerian elements in the prior. This is physically justified: drag is a satellite property independent of orbital geometry. Any bstar ↔ (mo, no_kozai) correlations emerge naturally from the likelihood during fitting and should not be imposed by the prior.

The resulting matrix is a symmetric positive-definite (7, 7) covariance suitable as a Gaussian prior θ ~ N(θ_tle, Σ) in Bayesian TLE fitting.

  1"""Build a 7×7 element-space prior covariance from a TLE.
  2
  3This example shows how to combine :func:`~sgp4jax.tle_ric_covariance`,
  4:func:`~sgp4jax.cov_ric_to_elements`, and :func:`~sgp4jax.tle_bstar_sigma`
  5to produce a full 7×7 prior covariance over the SGP4 element vector
  6
  7    (inclo, nodeo, ecco, argpo, mo, no_kozai, bstar)
  8
  9at a target time one day after TLE epoch.  This matrix is suitable as a
 10Gaussian prior in Bayesian TLE fitting.
 11
 12Workflow
 13--------
 141.  Parse the TLE and define the target time.
 152.  Get the 6×6 RIC position/velocity covariance via the empirical
 16    Vallado-style error-growth model.
 173.  Transform to 6-element space (inclo, nodeo, ecco, argpo, mo, no_kozai)
 18    using the Keplerian Jacobian.  This is a square, full-rank transform.
 194.  Append the empirical bstar variance as an independent 7th block.
 20
 21Why not use ``cov_ric_to_elements7``?
 22~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 23The 7-element Jacobian has shape (6, 7): six state dimensions cannot
 24uniquely constrain seven parameters.  Its pseudo-inverse produces a
 25rank-deficient (rank ≤ 6) covariance with numerically artefactual
 26cross-terms.  For a *prior*, bstar independence from the Keplerian
 27elements is physically justified — drag is a satellite property
 28independent of orbital geometry.  The bstar ↔ (mo, no_kozai)
 29correlations you care about are *posterior* correlations that emerge
 30naturally from the likelihood; they should not be baked into the prior.
 31"""
 32
 33import jax.numpy as jnp
 34
 35from sgp4jax import (
 36    tle_to_satrec,
 37    tle_ric_covariance,
 38    tle_bstar_sigma,
 39    cov_ric_to_elements,
 40)
 41
 42# ---------------------------------------------------------------------------
 43# TLE for the International Space Station (example)
 44# ---------------------------------------------------------------------------
 45LINE1 = "1 25544U 98067A   24001.50000000  .00003317  00000-0  38117-4 0  9994"
 46LINE2 = "2 25544  51.6416 247.4627 0006703 130.5360  13.6717 15.50026396432903"
 47
 48sat = tle_to_satrec(LINE1, LINE2)
 49
 50# ---------------------------------------------------------------------------
 51# Target time: 1 day after TLE epoch
 52# ---------------------------------------------------------------------------
 53jd = sat.jdsatepoch + 1.0   # whole part of Julian date
 54fr = sat.jdsatepochF        # fractional part (unchanged)
 55
 56# ---------------------------------------------------------------------------
 57# Step 1 — 6×6 RIC position/velocity covariance
 58#
 59# Diagonal matrix; units km² (position block) and km²/s² (velocity block).
 60# In-track (T) growth is scaled by bstar relative to the LEO population
 61# median.
 62# ---------------------------------------------------------------------------
 63cov_ric = tle_ric_covariance(sat, jd, fr)
 64
 65print("6×6 RIC 1-σ (km, km/s):")
 66for label, var in zip(["R", "T", "N", "Ṙ", "Ṫ", "Ṅ"], jnp.diag(cov_ric)):
 67    print(f"  σ_{label} = {float(jnp.sqrt(var)):.4f}")
 68
 69# ---------------------------------------------------------------------------
 70# Step 2 — Transform to 6-element space via Keplerian Jacobian
 71#
 72# cov_ric_to_elements chains:  RIC → TEME → elements6
 73# The result is a full-rank (6, 6) covariance.
 74# ---------------------------------------------------------------------------
 75cov_el6 = cov_ric_to_elements(cov_ric, sat, jd, fr)
 76
 77print("\n6-element 1-σ:")
 78el6_labels = [
 79    "inclo     (rad)",
 80    "nodeo     (rad)",
 81    "ecco          ",
 82    "argpo     (rad)",
 83    "mo        (rad)",
 84    "no_kozai  (rad/min)",
 85]
 86for label, var in zip(el6_labels, jnp.diag(cov_el6)):
 87    print(f"  σ_{label} = {float(jnp.sqrt(var)):.3e}")
 88
 89# ---------------------------------------------------------------------------
 90# Step 3 — Empirical bstar uncertainty
 91#
 92# tle_bstar_sigma returns a scalar 1-σ based on TLE age and |bstar|.
 93# At 1 day: σ_bstar ≈ 30% + 10%·1 day = 40% of |bstar|.
 94# ---------------------------------------------------------------------------
 95sigma_bstar = tle_bstar_sigma(sat, jd, fr)
 96print(f"\nbstar = {float(sat.bstar):.3e} km⁻¹")
 97print(f"σ_bstar at Δt=1 day = {float(sigma_bstar):.3e} km⁻¹  "
 98      f"({100*float(sigma_bstar)/abs(float(sat.bstar)):.0f}%)")
 99
# ---------------------------------------------------------------------------
# Step 4 — Assemble the 7×7 prior covariance
#
# bstar is appended as an independent block: no cross-terms with the
# Keplerian elements.  Positive-definiteness is checked below rather than
# assumed.
# ---------------------------------------------------------------------------
# A Jacobian transform of a covariance is symmetric only up to floating-point
# round-off; downstream consumers (Cholesky factorizations, Gaussian
# densities) expect an exactly symmetric matrix.
cov_el6_sym = 0.5 * (cov_el6 + cov_el6.T)
cov_prior = jnp.block([
    [cov_el6_sym,        jnp.zeros((6, 1))],
    [jnp.zeros((1, 6)),  jnp.array([[sigma_bstar ** 2]])],
])

print("\n7×7 prior: 1-σ diagonal:")
# B* is in inverse Earth radii; label padded to the 19-char label width.
el7_labels = el6_labels + ["bstar     (ER⁻¹)   "]
for label, var in zip(el7_labels, jnp.diag(cov_prior)):
    print(f"  σ_{label} = {float(jnp.sqrt(var)):.3e}")

print("\ncov_prior shape:", cov_prior.shape)
eigenvalues = jnp.linalg.eigvalsh(cov_prior)
eig_min = float(eigenvalues.min())
eig_max = float(eigenvalues.max())
print(f"Minimum eigenvalue: {eig_min:.3e}")
print(f"Maximum eigenvalue: {eig_max:.3e}")
# Report the measured result instead of asserting positive-definiteness.
if eig_min > 0.0:
    print(f"Positive-definite; condition number ≈ {eig_max / eig_min:.3e}")
else:
    print("WARNING: cov_prior is not positive-definite; "
          "check the inputs before using it as a prior")

# ---------------------------------------------------------------------------
# This cov_prior is suitable as a Gaussian prior:
#
#   θ ~ N(θ_tle, cov_prior)
#
# where θ = (inclo, nodeo, ecco, argpo, mo, no_kozai, bstar).
# ---------------------------------------------------------------------------