Examples
Orbit Fitting
This example demonstrates how to use sgp4jax’s differentiable SGP4 propagator to fit orbital elements to noisy position observations.
Because sgp4jax is built on JAX, we get automatic differentiation for free — enabling gradient-based optimization and analytic uncertainty estimation.
Workflow:
1. Generate synthetic observations from a known TLE with Gaussian noise
2. Define a forward model mapping orbital parameters → predicted positions
3. Fit 7 orbital parameters using JAX’s built-in BFGS optimizer
4. Estimate parameter uncertainties via Fisher information
#!/usr/bin/env python
"""Orbit Fitting with sgp4jax
==============================

This example demonstrates how to use sgp4jax's **differentiable SGP4
propagator** to fit orbital elements to noisy position observations.

Because sgp4jax is built on JAX, we get automatic differentiation for
free — enabling gradient-based optimization and analytic uncertainty
estimation.

Workflow:

1. Generate synthetic observations from a known TLE with Gaussian noise
2. Define a forward model mapping orbital parameters → predicted positions
3. Fit 7 orbital parameters using JAX's built-in BFGS optimizer
4. Estimate parameter uncertainties via Fisher information
"""

# %%
# Imports
# -------

24import jax
25import jax.numpy as jnp
26from jax.scipy.optimize import minimize as jax_minimize
27
28import sgp4jax
29from sgp4jax import WGS72, tle_to_satrec, propagate
30from sgp4jax._sgp4init import sgp4init
31
32
# %%
# 1. Generate Synthetic Observations
# -----------------------------------
#
# We start from a known ISS TLE as ground truth, propagate it to 50 evenly
# spaced time points over one day (1440 minutes), and add Gaussian noise with
# σ = 1 km to each position component. This noise level is realistic for
# ground-based radar tracking.
#
# TLEs are a fixed-column format: the padding inside ``line1`` and ``line2``
# is significant and must not be collapsed, or fields are read from the
# wrong columns.

line1 = "1 25544U 98067A   24045.51782528  .00016717  00000-0  10270-3 0  9006"
line2 = "2 25544  51.6400  10.2827 0003856 197.0300 163.0590 15.49560044439368"

sat_true = tle_to_satrec(line1, line2)

# 50 evenly spaced times (minutes since the TLE epoch) spanning one day.
times = jnp.linspace(0.0, 1440.0, 50)

# Propagate the true orbit to every time at once: the SatRec is shared
# (in_axes None) while the time axis is mapped (in_axes 0).
r_true, v_true, errs = jax.vmap(propagate, (None, 0))(sat_true, times)
print(f"Propagated {len(times)} time steps, positions shape: {r_true.shape}")

sigma = 1.0  # km
key = jax.random.PRNGKey(42)
noise = sigma * jax.random.normal(key, shape=r_true.shape)
r_obs = r_true + noise

print(f"RMS noise: {jnp.sqrt(jnp.mean(noise**2)):.3f} km")


# %%
# 2. The Forward Model
# ---------------------
#
# We define a function that takes 7 orbital parameters, builds a ``SatRec``
# via ``sgp4init``, and propagates to all observation times.
#
# ======== ============ ========================================
# Index    Parameter    Description
# ======== ============ ========================================
# 0        ``inclo``    Inclination (rad)
# 1        ``nodeo``    Right ascension of ascending node (rad)
# 2        ``ecco``     Eccentricity
# 3        ``argpo``    Argument of perigee (rad)
# 4        ``mo``       Mean anomaly (rad)
# 5        ``no_kozai`` Mean motion (rad/min), Kozai
# 6        ``bstar``    Drag coefficient (B*)
# ======== ============ ========================================

def predict_positions(params, gravity, epoch, jdsatepoch, jdsatepochF, times):
    """Forward model: map orbital parameters to predicted positions.

    Parameters
    ----------
    params : array, shape (7,)
        ``[inclo, nodeo, ecco, argpo, mo, no_kozai, bstar]`` — angles in
        radians, mean motion in rad/min, ``bstar`` the drag term.
    gravity
        Gravity-model constants, forwarded unchanged to ``sgp4init``.
    epoch : float
        Epoch value forwarded to ``sgp4init`` (this example passes days
        since JD 2433281.5).
    jdsatepoch, jdsatepochF : float
        The epoch Julian date split into two parts, forwarded to ``sgp4init``.
    times : array, shape (n_times,)
        Minutes since epoch at which to evaluate the orbit.

    Returns
    -------
    array, shape (n_times, 3)
        Predicted positions. ``ndot``/``nddot`` are held at zero, and the
        propagator's velocity and error outputs are discarded.
    """
    # Slot of each element inside ``params``.
    i_inc, i_node, i_ecc, i_argp, i_mean_anom, i_mean_motion, i_bstar = range(7)

    satrec = sgp4init(
        gravity,
        epoch,
        params[i_bstar],
        0.0,  # ndot, held fixed
        0.0,  # nddot, held fixed
        params[i_ecc],
        params[i_argp],
        params[i_inc],
        params[i_mean_anom],
        params[i_mean_motion],
        params[i_node],
        jdsatepoch,
        jdsatepochF,
    )

    # One SatRec shared across all times (axis None); map over ``times`` (axis 0).
    propagate_all = jax.vmap(propagate, in_axes=(None, 0))
    positions, _, _ = propagate_all(satrec, times)
    return positions


# %%
# 3. Loss Function
# -----------------
#
# Weighted sum of squared residuals:
#
# .. math::
#
#    \mathcal{L}(\boldsymbol{\theta})
#    = \frac{1}{2\sigma^2}
#    \sum_{i=1}^{N}
#    \|\mathbf{r}_\text{pred}(t_i;\boldsymbol{\theta})
#    - \mathbf{r}_\text{obs}(t_i)\|^2

def loss_fn(params, gravity, epoch, jdsatepoch, jdsatepochF, times, r_obs, sigma):
    """Return the noise-weighted sum of squared position residuals.

    Evaluates ``predict_positions`` at ``times`` and returns
    ``0.5 * sum((r_pred - r_obs)**2) / sigma**2``. Every coordinate of every
    observation enters the sum with the same weight ``1 / sigma**2``.
    """
    r_pred = predict_positions(params, gravity, epoch, jdsatepoch, jdsatepochF, times)
    chi2 = jnp.sum((r_pred - r_obs) ** 2)
    return 0.5 * chi2 / sigma**2


# %%
# 4. Initial Guess & Optimization
# --------------------------------
#
# We start from a slightly perturbed version of the true parameters and
# use JAX's built-in BFGS optimizer (``jax.scipy.optimize.minimize``).
#
# Parameter scaling is critical: we normalize by the expected perturbation
# size so that the BFGS initial Hessian approximation (identity) produces
# conservative step sizes across all parameters.

# Element names, in the same order as the forward model's ``params`` vector.
param_names = ["inclo", "nodeo", "ecco", "argpo", "mo", "no_kozai", "bstar"]

# Ground-truth parameter vector, read straight off the TLE-derived SatRec.
true_params = jnp.array([getattr(sat_true, attr) for attr in param_names])

gravity = WGS72
jdsatepoch = float(sat_true.jdsatepoch)
jdsatepochF = float(sat_true.jdsatepochF)
# Epoch for sgp4init, counted in days from JD 2433281.5
# (1949 December 31 00:00 UT).
epoch = jdsatepoch + jdsatepochF - 2433281.5

# Displace each element from the truth by a Gaussian draw whose standard
# deviation is that element's expected perturbation size.
key2 = jax.random.PRNGKey(123)
perturbation = jnp.array([1e-4, 1e-4, 1e-5, 1e-4, 1e-4, 1e-6, 1e-6])
x0 = true_params + perturbation * jax.random.normal(key2, shape=perturbation.shape)

# One unit step in the scaled space equals one perturbation-sized step in
# physical space, which keeps the BFGS line search well behaved.
param_scale = perturbation


def scaled_loss(x_scaled, param_scale, gravity, epoch, jdsatepoch,
                jdsatepochF, times, r_obs, sigma):
    """Evaluate ``loss_fn`` in normalized parameter space.

    The optimizer works on ``x_scaled``. The physical parameter vector is
    recovered element-wise as ``x_scaled * param_scale``. All remaining
    arguments are passed to ``loss_fn`` unchanged.
    """
    model_args = (gravity, epoch, jdsatepoch, jdsatepochF, times, r_obs, sigma)
    return loss_fn(x_scaled * param_scale, *model_args)


x0_scaled = x0 / param_scale

result = jax_minimize(
    scaled_loss,
    x0_scaled,
    args=(param_scale, gravity, epoch, jdsatepoch, jdsatepochF,
          times, r_obs, sigma),
    method="BFGS",
)

# Map the optimum back from the scaled space to physical units.
params_fit = result.x * param_scale
print(f"\nOptimization finished: fun_val = {float(result.fun):.4f}, "
      f"nit = {int(result.nit)}")
# jax.scipy.optimize.minimize does not raise on failure, so report it
# explicitly rather than silently presenting a non-converged fit.
if not bool(result.success):
    print(f"WARNING: BFGS did not report convergence "
          f"(status = {int(result.status)})")
print()

# Build the header from the same field widths as the rows so columns line up.
table_header = (f"{'Parameter':12s} {'Initial':>14s} {'Optimized':>14s} "
                f"{'True':>14s}")
print(table_header)
print("-" * len(table_header))
for name, xi, fi, ti in zip(param_names, x0, params_fit, true_params):
    print(f"{name:12s} {float(xi):14.8f} {float(fi):14.8f} {float(ti):14.8f}")


# %%
# 5. Fisher Information & Parameter Uncertainties
# -------------------------------------------------
#
# With the Jacobian of the forward model we estimate parameter
# uncertainties via the Fisher information matrix:
#
# .. math::
#
#    \mathbf{F} = \frac{1}{\sigma^2}\,\mathbf{J}^\top\mathbf{J},
#    \qquad
#    \mathrm{Cov}(\boldsymbol{\theta}) \approx \mathbf{F}^{-1}
#
# The 1-σ uncertainties are
# :math:`\sqrt{\mathrm{diag}(\mathbf{F}^{-1})}`.
#
# The parameters differ by many orders of magnitude, so :math:`\mathbf{F}`
# is badly scaled. We invert the equilibrated matrix
# :math:`\mathbf{D}\mathbf{F}\mathbf{D}` with
# :math:`\mathbf{D} = \mathrm{diag}(\mathbf{F})^{-1/2}` (unit diagonal) and
# recover :math:`\mathbf{F}^{-1} = \mathbf{D}(\mathbf{D}\mathbf{F}\mathbf{D})^{-1}\mathbf{D}`.

# d(positions)/d(params) at the best fit: shape (n_times, 3, n_params).
jacobian_fn = jax.jit(jax.jacobian(predict_positions))
J_full = jacobian_fn(params_fit, gravity, epoch, jdsatepoch, jdsatepochF, times)

# Flatten (time, xyz) into a single residual axis: (n_times * 3, n_params).
n_times = times.shape[0]
J = J_full.reshape(n_times * 3, -1)

F = J.T @ J / sigma**2

# Jacobi equilibration; mathematically cov == inv(F), but better conditioned.
# A parameter with zero sensitivity makes F singular either way.
inv_sqrt_diag = 1.0 / jnp.sqrt(jnp.diag(F))
equilibration = jnp.outer(inv_sqrt_diag, inv_sqrt_diag)
cov = jnp.linalg.inv(F * equilibration) * equilibration
uncertainties = jnp.sqrt(jnp.diag(cov))

# Header columns match the row layout: name(12) value(14) " +/- " sigma.
unc_header = f"{'Parameter':12s} {'Best Fit':>14s}     1-sigma Uncertainty"
print("\n" + unc_header)
print("-" * len(unc_header))
for name, fi, ui in zip(param_names, params_fit, uncertainties):
    print(f"{name:12s} {float(fi):14.8f} +/- {float(ui):.2e}")


# %%
# 6. Results
# ----------
#
# Visualize the position residuals and the parameter correlation matrix.

try:
    import matplotlib.pyplot as plt

    # Residuals of the fitted orbit against the noisy observations.
    r_fit = predict_positions(
        params_fit, gravity, epoch, jdsatepoch, jdsatepochF, times)
    residuals = r_fit - r_obs

    fig, axes = plt.subplots(2, 1, figsize=(10, 8))

    # Top panel: per-component residuals with ±sigma reference lines.
    ax = axes[0]
    for i, label in enumerate("XYZ"):
        ax.plot(times, residuals[:, i], ".", label=label, markersize=4)
    sigma_line = dict(color="gray", ls="--", alpha=0.5)
    ax.axhline(sigma, label=f"$\\pm${sigma} km", **sigma_line)
    ax.axhline(-sigma, **sigma_line)
    ax.set(
        xlabel="Time since epoch (min)",
        ylabel="Residual (km)",
        title="Position Residuals (fit - observed)",
    )
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Bottom panel: covariance normalized by the outer product of the
    # 1-sigma values, i.e. the correlation matrix.
    ax = axes[1]
    std = jnp.sqrt(jnp.diag(cov))
    corr = cov / jnp.outer(std, std)
    im = ax.imshow(corr, cmap="RdBu_r", vmin=-1, vmax=1)
    tick_positions = range(7)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(param_names, rotation=45, ha="right")
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(param_names)
    ax.set_title("Parameter Correlation Matrix")
    fig.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()

except ImportError:
    # Plots are optional; the numerical results above are already printed.
    print("(matplotlib not available — skipping plots)")


# %%
# Summary
# -------
#
# This example demonstrated a complete orbit determination workflow:
#
# 1. **Synthetic data** — propagated a known TLE and added 1 km Gaussian noise
# 2. **Forward model** — ``sgp4init`` + ``propagate`` map orbital
#    parameters to positions
# 3. **BFGS optimization** — JAX automatic differentiation provides
#    exact gradients through the entire SGP4 computation
# 4. **Uncertainty estimation** — the Fisher information matrix from
#    the Jacobian gives 1-σ parameter uncertainties and correlations
#
# The key advantage of sgp4jax is **differentiability**: gradients,
# Jacobians, and Hessians are computed automatically and efficiently,
# enabling gradient-based optimization and rigorous uncertainty
# quantification without finite differences.