Examples¶
Orbit Fitting¶
This example demonstrates how to use sgp4jax’s differentiable SGP4 propagator to fit orbital elements to noisy position observations.
Because sgp4jax is built on JAX, we get automatic differentiation for free — enabling gradient-based optimization and analytic uncertainty estimation.
Workflow:
1. Generate synthetic observations from a known TLE with Gaussian noise
2. Define a forward model mapping orbital parameters → predicted positions
3. Fit 7 orbital parameters using JAX’s built-in BFGS optimizer
4. Estimate parameter uncertainties via Fisher information
1#!/usr/bin/env python
2"""Orbit Fitting with sgp4jax
3==============================
4
5This example demonstrates how to use sgp4jax's **differentiable SGP4
6propagator** to fit orbital elements to noisy position observations.
7
8Because sgp4jax is built on JAX, we get automatic differentiation for
9free — enabling gradient-based optimization and analytic uncertainty
10estimation.
11
12Workflow:
13
141. Generate synthetic observations from a known TLE with Gaussian noise
152. Define a forward model mapping orbital parameters → predicted positions
163. Fit 7 orbital parameters using JAX's built-in BFGS optimizer
174. Estimate parameter uncertainties via Fisher information
18"""
19
20# %%
21# Imports
22# -------
23
24import jax
25import jax.numpy as jnp
26from jax.scipy.optimize import minimize as jax_minimize
27
28import sgp4jax
29from sgp4jax import WGS72, tle_to_satrec, propagate
30from sgp4jax._sgp4init import sgp4init
31
32
33# %%
34# 1. Generate Synthetic Observations
35# -----------------------------------
36#
37# We start from a known ISS TLE as ground truth, propagate it to 50 time
38# points over one day (1440 minutes), and add Gaussian noise with σ = 1 km
39# to the positions. This noise level is realistic for ground-based radar
40# tracking.
41
# Ground-truth element set. TLE is a fixed-column format (69 characters
# per line), so the runs of spaces inside these strings are significant
# and must not be collapsed.
line1 = "1 25544U 98067A   24045.51782528  .00016717  00000-0  10270-3 0  9006"
line2 = "2 25544  51.6400  10.2827 0003856 197.0300 163.0590 15.49560044439368"

sat_true = tle_to_satrec(line1, line2)

# 50 evenly spaced epochs covering one day, in minutes since TLE epoch.
times = jnp.linspace(0.0, 1440.0, 50)

# Vectorize the single-epoch propagator over the time axis only; the
# satellite record is shared by every evaluation.
r_true, v_true, errs = jax.vmap(propagate, (None, 0))(sat_true, times)
print(f"Propagated {len(times)} time steps, positions shape: {r_true.shape}")

sigma = 1.0  # km
key = jax.random.PRNGKey(42)  # fixed seed keeps the example reproducible
noise = sigma * jax.random.normal(key, shape=r_true.shape)
r_obs = r_true + noise

print(f"RMS noise: {jnp.sqrt(jnp.mean(noise**2)):.3f} km")
58
59
60# %%
61# 2. The Forward Model
62# ---------------------
63#
64# We define a function that takes 7 orbital parameters, builds a ``SatRec``
65# via ``sgp4init``, and propagates to all observation times.
66#
67# ======== ============ ================================
68# Index Parameter Description
69# ======== ============ ================================
70# 0 ``inclo`` Inclination (rad)
71# 1 ``nodeo`` Right ascension of ascending node (rad)
72# 2 ``ecco`` Eccentricity
73# 3 ``argpo`` Argument of perigee (rad)
74# 4 ``mo`` Mean anomaly (rad)
75# 5 ``no_kozai`` Mean motion (rad/min), Kozai
76# 6 ``bstar`` Drag coefficient (B*)
77# ======== ============ ================================
78
def predict_positions(params, gravity, epoch, jdsatepoch, jdsatepochF, times):
    """Forward model: orbital parameters -> predicted positions.

    ``params`` is indexed as
    ``(inclo, nodeo, ecco, argpo, mo, no_kozai, bstar)``. A satellite
    record is built with ``sgp4init`` (``ndot`` and ``nddot`` are held at
    zero) and propagated to every entry of ``times``. Only the position
    output of ``propagate`` is returned; velocities and error codes are
    discarded.
    """
    inclo, nodeo, ecco, argpo, mo, no_kozai, bstar = (
        params[k] for k in range(7)
    )

    # The mean-motion derivatives are not fitted; they stay fixed at zero.
    ndot = 0.0
    nddot = 0.0

    sat = sgp4init(
        gravity, epoch, bstar,
        ndot, nddot,
        ecco, argpo, inclo, mo, no_kozai, nodeo,
        jdsatepoch, jdsatepochF,
    )

    # Map the propagator over the time axis while sharing one record.
    propagate_over_times = jax.vmap(propagate, (None, 0))
    positions, _velocities, _error_codes = propagate_over_times(sat, times)
    return positions  # (n_times, 3)
97
98
99# %%
100# 3. Loss Function
101# -----------------
102#
103# Weighted sum of squared residuals:
104#
105# .. math::
106#
107# \mathcal{L}(\boldsymbol{\theta})
108# = \frac{1}{2\sigma^2}
109# \sum_{i=1}^{N}
110# \|\mathbf{r}_\text{pred}(t_i;\boldsymbol{\theta})
111# - \mathbf{r}_\text{obs}(t_i)\|^2
112
def loss_fn(params, gravity, epoch, jdsatepoch, jdsatepochF, times, r_obs, sigma):
    """Return half the sum of squared position residuals divided by sigma**2."""
    misfit = predict_positions(
        params, gravity, epoch, jdsatepoch, jdsatepochF, times
    ) - r_obs
    half_sum_of_squares = 0.5 * jnp.sum(jnp.square(misfit))
    return half_sum_of_squares / sigma**2
118
119
120# %%
121# 4. Initial Guess & Optimization
122# --------------------------------
123#
124# We start from a slightly perturbed version of the true parameters and
125# use JAX's built-in BFGS optimizer (``jax.scipy.optimize.minimize``).
126#
127# Parameter scaling is critical: we normalize by the expected perturbation
128# size so that the BFGS initial Hessian approximation (identity) produces
129# conservative step sizes across all parameters.
130
# Order of the fitted parameters; it matches the indexing used by the
# forward model.
param_names = ["inclo", "nodeo", "ecco", "argpo", "mo", "no_kozai", "bstar"]

# Ground-truth parameter vector read off the parsed TLE, in that order.
true_params = jnp.array([getattr(sat_true, name) for name in param_names])

gravity = WGS72
jdsatepoch = float(sat_true.jdsatepoch)
jdsatepochF = float(sat_true.jdsatepochF)
# Epoch expressed as days relative to Julian date 2433281.5.
epoch = jdsatepoch + jdsatepochF - 2433281.5

# Per-parameter perturbation magnitudes used to build the starting point.
perturbation = jnp.array([1e-4, 1e-4, 1e-5, 1e-4, 1e-4, 1e-6, 1e-6])
key2 = jax.random.PRNGKey(123)
x0 = true_params + perturbation * jax.random.normal(key2, shape=(7,))

# Parameter scaling: a unit step in the scaled space corresponds to a
# perturbation-sized step in physical space, which keeps the BFGS line
# search stable.
param_scale = perturbation
156
157
def scaled_loss(x_scaled, param_scale, gravity, epoch, jdsatepoch,
                jdsatepochF, times, r_obs, sigma):
    """Evaluate ``loss_fn`` at ``x_scaled * param_scale``.

    The optimizer works on the normalized vector ``x_scaled``; it is
    mapped back to physical parameters before the loss is computed.
    """
    physical_params = param_scale * x_scaled
    return loss_fn(
        physical_params, gravity, epoch, jdsatepoch, jdsatepochF,
        times, r_obs, sigma,
    )
164
165
# Start the optimizer from the initial guess expressed in scaled units.
x0_scaled = x0 / param_scale

result = jax_minimize(
    scaled_loss, x0_scaled,
    args=(param_scale, gravity, epoch, jdsatepoch, jdsatepochF,
          times, r_obs, sigma),
    method="BFGS",
)

# Map the solution back from scaled units to physical parameters.
params_fit = result.x * param_scale
print(f"\nOptimization finished: fun_val = {float(result.fun):.4f}, "
      f"nit = {int(result.nit)}")

# The optimizer returns a result object whether or not it converged, so
# surface a failed run instead of presenting its output as a valid fit.
if not bool(result.success):
    print(f"WARNING: BFGS did not report convergence "
          f"(status = {int(result.status)}); the fitted parameters "
          f"below may be unreliable.")

print()
print("Parameter Initial Optimized True")
print("-" * 70)
for name, xi, fi, ti in zip(param_names, x0, params_fit, true_params):
    print(f"{name:12s} {float(xi):14.8f} {float(fi):14.8f} {float(ti):14.8f}")
183
184
185# %%
186# 5. Fisher Information & Parameter Uncertainties
187# -------------------------------------------------
188#
189# With the Jacobian of the forward model we estimate parameter
190# uncertainties via the Fisher information matrix:
191#
192# .. math::
193#
194# \mathbf{F} = \frac{1}{\sigma^2}\,\mathbf{J}^\top\mathbf{J},
195# \qquad
196# \mathrm{Cov}(\boldsymbol{\theta}) \approx \mathbf{F}^{-1}
197#
198# The 1-σ uncertainties are
199# :math:`\sqrt{\mathrm{diag}(\mathbf{F}^{-1})}`.
200
# Jacobian of the predicted positions with respect to the 7 parameters
# (differentiation is with respect to the first argument).
jacobian_fn = jax.jit(jax.jacobian(predict_positions))
J_full = jacobian_fn(params_fit, gravity, epoch, jdsatepoch, jdsatepochF, times)

# Flatten (time, xyz) into one observation axis: one row per scalar
# observation, one column per parameter.
n_times = times.shape[0]
J = J_full.reshape(n_times * 3, 7)

F = J.T @ J / sigma**2

# The parameters differ by many orders of magnitude, which makes F badly
# conditioned. Equilibrate it to unit diagonal before inverting and undo
# the scaling afterwards; with D = diag(d) this uses the identity
# F^{-1} = D (D F D)^{-1} D, so the covariance is mathematically unchanged.
d = 1.0 / jnp.sqrt(jnp.diag(F))
dd = jnp.outer(d, d)
cov = jnp.linalg.inv(F * dd) * dd
uncertainties = jnp.sqrt(jnp.diag(cov))

print("\nParameter Best Fit 1-sigma Uncertainty")
print("-" * 55)
for name, fi, ui in zip(param_names, params_fit, uncertainties):
    print(f"{name:12s} {float(fi):14.8f} +/- {float(ui):.2e}")
215
216
217# %%
218# 6. Results
219# ----------
220#
221# Visualize the position residuals and the parameter correlation matrix.
222
# Only the import is guarded: a genuine error raised while plotting must
# not be misreported as "matplotlib not available".
try:
    import matplotlib.pyplot as plt
except ImportError:
    print("(matplotlib not available — skipping plots)")
else:
    r_fit = predict_positions(
        params_fit, gravity, epoch, jdsatepoch, jdsatepochF, times)
    residuals = r_fit - r_obs

    fig, axes = plt.subplots(2, 1, figsize=(10, 8))

    # Plot 1: Position residuals, one series per Cartesian component,
    # with dashed guides at +/- the observation noise level.
    ax = axes[0]
    for i, label in enumerate(["X", "Y", "Z"]):
        ax.plot(times, residuals[:, i], ".", label=label, markersize=4)
    ax.axhline(sigma, color="gray", ls="--", alpha=0.5, label=f"$\\pm${sigma} km")
    ax.axhline(-sigma, color="gray", ls="--", alpha=0.5)
    ax.set_xlabel("Time since epoch (min)")
    ax.set_ylabel("Residual (km)")
    ax.set_title("Position Residuals (fit - observed)")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 2: Correlation matrix, i.e. the covariance normalized by the
    # outer product of the 1-sigma values.
    ax = axes[1]
    std = jnp.sqrt(jnp.diag(cov))
    corr = cov / jnp.outer(std, std)
    im = ax.imshow(corr, cmap="RdBu_r", vmin=-1, vmax=1)
    ax.set_xticks(range(7))
    ax.set_xticklabels(param_names, rotation=45, ha="right")
    ax.set_yticks(range(7))
    ax.set_yticklabels(param_names)
    ax.set_title("Parameter Correlation Matrix")
    fig.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()
261
262
263# %%
264# Summary
265# -------
266#
267# This example demonstrated a complete orbit determination workflow:
268#
269# 1. **Synthetic data** — propagated a known TLE and added 1 km noise
270# 2. **Forward model** — ``sgp4init`` + ``propagate`` map orbital
271# parameters to positions
272# 3. **BFGS optimization** — JAX automatic differentiation provides
273# exact gradients through the entire SGP4 computation
274# 4. **Uncertainty estimation** — the Fisher information matrix from
275# the Jacobian gives 1-σ parameter uncertainties and correlations
276#
277# The key advantage of sgp4jax is **differentiability**: gradients,
278# Jacobians, and Hessians are computed automatically and efficiently,
279# enabling gradient-based optimization and rigorous uncertainty
280# quantification without finite differences.
7×7 Element-Space Prior Covariance¶
This example shows how to build a full 7×7 prior covariance matrix over
the SGP4 element vector (inclo, nodeo, ecco, argpo, mo, no_kozai, bstar)
at a target time one day after TLE epoch.
The approach combines three functions:
- tle_ric_covariance() — empirical 6×6 RIC position/velocity covariance based on TLE age and drag coefficient.
- cov_ric_to_elements() — transforms the RIC covariance to 6-element space (inclo, nodeo, ecco, argpo, mo, no_kozai) via the Keplerian Jacobian. This is a square, full-rank transform.
- tle_bstar_sigma() — empirical 1-σ for bstar based on TLE age and atmospheric-density variability, appended as an independent 7th block.
bstar is treated as independent of the Keplerian elements in the prior. This is physically justified: drag is a satellite property independent of orbital geometry. Any bstar ↔ (mo, no_kozai) correlations emerge naturally from the likelihood during fitting and should not be imposed by the prior.
The resulting matrix is a symmetric positive-definite (7, 7) covariance
suitable as a Gaussian prior θ ~ N(θ_tle, Σ) in Bayesian TLE fitting.
1"""Build a 7×7 element-space prior covariance from a TLE.
2
3This example shows how to combine :func:`~sgp4jax.tle_ric_covariance`,
4:func:`~sgp4jax.cov_ric_to_elements`, and :func:`~sgp4jax.tle_bstar_sigma`
5to produce a full 7×7 prior covariance over the SGP4 element vector
6
7 (inclo, nodeo, ecco, argpo, mo, no_kozai, bstar)
8
9at a target time one day after TLE epoch. This matrix is suitable as a
10Gaussian prior in Bayesian TLE fitting.
11
12Workflow
13--------
141. Parse the TLE and define the target time.
152. Get the 6×6 RIC position/velocity covariance via the empirical
16 Vallado-style error-growth model.
173. Transform to 6-element space (inclo, nodeo, ecco, argpo, mo, no_kozai)
18 using the Keplerian Jacobian. This is a square, full-rank transform.
194. Append the empirical bstar variance as an independent 7th block.
20
21Why not use ``cov_ric_to_elements7``?
22~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
23The 7-element Jacobian has shape (6, 7): six state dimensions cannot
24uniquely constrain seven parameters. Its pseudo-inverse produces a
25rank-deficient (rank ≤ 6) covariance with numerically artefactual
26cross-terms. For a *prior*, bstar independence from the Keplerian
27elements is physically justified — drag is a satellite property
28independent of orbital geometry. The bstar ↔ (mo, no_kozai)
29correlations you care about are *posterior* correlations that emerge
30naturally from the likelihood; they should not be baked into the prior.
31"""
32
33import jax.numpy as jnp
34
35from sgp4jax import (
36 tle_to_satrec,
37 tle_ric_covariance,
38 tle_bstar_sigma,
39 cov_ric_to_elements,
40)
41
42# ---------------------------------------------------------------------------
43# TLE for the International Space Station (example)
44# ---------------------------------------------------------------------------
# TLE is a fixed-column format (69 characters per line), so the runs of
# spaces inside these strings are significant and must not be collapsed.
LINE1 = "1 25544U 98067A   24001.50000000  .00003317  00000-0  38117-4 0  9994"
LINE2 = "2 25544  51.6416 247.4627 0006703 130.5360  13.6717 15.50026396432903"

sat = tle_to_satrec(LINE1, LINE2)

# ---------------------------------------------------------------------------
# Target time: 1 day after TLE epoch
# ---------------------------------------------------------------------------
jd = sat.jdsatepoch + 1.0  # whole part of Julian date
fr = sat.jdsatepochF  # fractional part (unchanged)
55
56# ---------------------------------------------------------------------------
57# Step 1 — 6×6 RIC position/velocity covariance
58#
59# Diagonal matrix; units km² (position block) and km²/s² (velocity block).
60# In-track (T) growth is scaled by bstar relative to the LEO population
61# median.
62# ---------------------------------------------------------------------------
# Empirical RIC position/velocity covariance at the target time.
cov_ric = tle_ric_covariance(sat, jd, fr)

print("6×6 RIC 1-σ (km, km/s):")
# 1-sigma values are the square roots of the covariance diagonal.
ric_labels = ["R", "T", "N", "Ṙ", "Ṫ", "Ṅ"]
ric_sigmas = jnp.sqrt(jnp.diag(cov_ric))
for label, ric_sigma in zip(ric_labels, ric_sigmas):
    print(f" σ_{label} = {float(ric_sigma):.4f}")
68
69# ---------------------------------------------------------------------------
70# Step 2 — Transform to 6-element space via Keplerian Jacobian
71#
72# cov_ric_to_elements chains: RIC → TEME → elements6
73# The result is a full-rank (6, 6) covariance.
74# ---------------------------------------------------------------------------
# Map the RIC covariance into the 6-element space.
cov_el6 = cov_ric_to_elements(cov_ric, sat, jd, fr)

# Labels follow the element order (inclo, nodeo, ecco, argpo, mo, no_kozai).
el6_labels = [
    "inclo (rad)",
    "nodeo (rad)",
    "ecco ",
    "argpo (rad)",
    "mo (rad)",
    "no_kozai (rad/min)",
]

print("\n6-element 1-σ:")
el6_sigmas = jnp.sqrt(jnp.diag(cov_el6))
for label, el_sigma in zip(el6_labels, el6_sigmas):
    print(f" σ_{label} = {float(el_sigma):.3e}")
88
89# ---------------------------------------------------------------------------
90# Step 3 — Empirical bstar uncertainty
91#
92# tle_bstar_sigma returns a scalar 1-σ based on TLE age and |bstar|.
93# At 1 day: σ_bstar ≈ 30% + 10%·1 day = 40% of |bstar|.
94# ---------------------------------------------------------------------------
# Empirical 1-sigma for bstar at the target time.
sigma_bstar = tle_bstar_sigma(sat, jd, fr)

# Convert once to Python floats for reporting.
bstar_value = float(sat.bstar)
sigma_bstar_value = float(sigma_bstar)
sigma_bstar_percent = 100*sigma_bstar_value/abs(bstar_value)

print(f"\nbstar = {bstar_value:.3e} km⁻¹")
print(f"σ_bstar at Δt=1 day = {sigma_bstar_value:.3e} km⁻¹ "
      f"({sigma_bstar_percent:.0f}%)")
99
100# ---------------------------------------------------------------------------
101# Step 4 — Assemble the 7×7 prior covariance
102#
103# bstar is appended as an independent block: no cross-terms with the
104# Keplerian elements. The result is symmetric positive-definite.
105# ---------------------------------------------------------------------------
# Block-diagonal assembly: the 6×6 element covariance in the upper-left,
# the bstar variance in the lower-right, and zero cross-terms.
bstar_block = jnp.array([[sigma_bstar ** 2]])
upper_row = [cov_el6, jnp.zeros((6, 1))]
lower_row = [jnp.zeros((1, 6)), bstar_block]
cov_prior = jnp.block([upper_row, lower_row])

el7_labels = el6_labels + ["bstar (km⁻¹) "]

print("\n7×7 prior: 1-σ diagonal:")
prior_sigmas = jnp.sqrt(jnp.diag(cov_prior))
for label, prior_sigma in zip(el7_labels, prior_sigmas):
    print(f" σ_{label} = {float(prior_sigma):.3e}")

print("\ncov_prior shape:", cov_prior.shape)
# Eigenvalues of the symmetric matrix; the smallest one is reported as a
# positive-definiteness check.
eigenvalues = jnp.linalg.eigvalsh(cov_prior)
print(f"Minimum eigenvalue: {float(jnp.min(eigenvalues)):.3e} (≈ machine-ε × max eigenvalue — numerically PD)")
119
120# ---------------------------------------------------------------------------
121# This cov_prior is suitable as a Gaussian prior:
122#
123# θ ~ N(θ_tle, cov_prior)
124#
125# where θ = (inclo, nodeo, ecco, argpo, mo, no_kozai, bstar).
126# ---------------------------------------------------------------------------