Examples

Orbit Fitting

This example demonstrates how to use sgp4jax’s differentiable SGP4 propagator to fit orbital elements to noisy position observations.

Because sgp4jax is built on JAX, we get automatic differentiation for free — enabling gradient-based optimization and analytic uncertainty estimation.

Workflow:

  1. Generate synthetic observations from a known TLE with Gaussian noise

  2. Define a forward model mapping orbital parameters → predicted positions

  3. Fit 7 orbital parameters using JAX’s built-in BFGS optimizer

  4. Estimate parameter uncertainties via Fisher information

  1#!/usr/bin/env python
  2"""Orbit Fitting with sgp4jax
  3==============================
  4
  5This example demonstrates how to use sgp4jax's **differentiable SGP4
  6propagator** to fit orbital elements to noisy position observations.
  7
  8Because sgp4jax is built on JAX, we get automatic differentiation for
  9free — enabling gradient-based optimization and analytic uncertainty
 10estimation.
 11
 12Workflow:
 13
 141. Generate synthetic observations from a known TLE with Gaussian noise
 152. Define a forward model mapping orbital parameters → predicted positions
 163. Fit 7 orbital parameters using JAX's built-in BFGS optimizer
 174. Estimate parameter uncertainties via Fisher information
 18"""
 19
 20# %%
 21# Imports
 22# -------
 23
 24import jax
 25import jax.numpy as jnp
 26from jax.scipy.optimize import minimize as jax_minimize
 27
 28import sgp4jax
 29from sgp4jax import WGS72, tle_to_satrec, propagate
 30from sgp4jax._sgp4init import sgp4init
 31
 32
 33# %%
 34# 1. Generate Synthetic Observations
 35# -----------------------------------
 36#
 37# We start from a known ISS TLE as ground truth, propagate it to 50 time
 38# points over one day (1440 minutes), and add Gaussian noise with σ = 1 km
 39# to the positions.  This noise level is realistic for ground-based radar
 40# tracking.
 41
 42line1 = "1 25544U 98067A   24045.51782528  .00016717  00000-0  10270-3 0  9006"
 43line2 = "2 25544  51.6400  10.2827 0003856 197.0300 163.0590 15.49560044439368"
 44
 45sat_true = tle_to_satrec(line1, line2)
 46
 47times = jnp.linspace(0.0, 1440.0, 50)
 48
 49r_true, v_true, errs = jax.vmap(propagate, (None, 0))(sat_true, times)
 50print(f"Propagated {len(times)} time steps, positions shape: {r_true.shape}")
 51
 52sigma = 1.0  # km
 53key = jax.random.PRNGKey(42)
 54noise = sigma * jax.random.normal(key, shape=r_true.shape)
 55r_obs = r_true + noise
 56
 57print(f"RMS noise: {jnp.sqrt(jnp.mean(noise**2)):.3f} km")
 58
 59
 60# %%
 61# 2. The Forward Model
 62# ---------------------
 63#
 64# We define a function that takes 7 orbital parameters, builds a ``SatRec``
 65# via ``sgp4init``, and propagates to all observation times.
 66#
 67# ========  ============  ================================
 68# Index     Parameter     Description
 69# ========  ============  ================================
 70# 0         ``inclo``     Inclination (rad)
 71# 1         ``nodeo``     Right ascension of ascending node (rad)
 72# 2         ``ecco``      Eccentricity
 73# 3         ``argpo``     Argument of perigee (rad)
 74# 4         ``mo``        Mean anomaly (rad)
 75# 5         ``no_kozai``  Mean motion (rad/min), Kozai
 76# 6         ``bstar``     Drag coefficient (B*)
 77# ========  ============  ================================
 78
 79def predict_positions(params, gravity, epoch, jdsatepoch, jdsatepochF, times):
 80    """Forward model: orbital parameters -> predicted positions."""
 81    inclo = params[0]
 82    nodeo = params[1]
 83    ecco = params[2]
 84    argpo = params[3]
 85    mo = params[4]
 86    no_kozai = params[5]
 87    bstar = params[6]
 88
 89    sat = sgp4init(
 90        gravity, epoch, bstar,
 91        0.0, 0.0,  # ndot, nddot (fixed)
 92        ecco, argpo, inclo, mo, no_kozai, nodeo,
 93        jdsatepoch, jdsatepochF,
 94    )
 95    r, v, err = jax.vmap(propagate, (None, 0))(sat, times)
 96    return r  # (n_times, 3)
 97
 98
 99# %%
100# 3. Loss Function
101# -----------------
102#
103# Weighted sum of squared residuals:
104#
105# .. math::
106#
107#    \mathcal{L}(\boldsymbol{\theta})
108#    = \frac{1}{2\sigma^2}
109#      \sum_{i=1}^{N}
110#      \|\mathbf{r}_\text{pred}(t_i;\boldsymbol{\theta})
111#        - \mathbf{r}_\text{obs}(t_i)\|^2
112
def loss_fn(params, gravity, epoch, jdsatepoch, jdsatepochF, times, r_obs, sigma):
    """Return half the sum of squared position residuals divided by sigma**2.

    The residual is ``predict_positions(...) - r_obs``; every element of it
    is squared and summed, so the result is a scalar.  All arguments other
    than ``r_obs`` and ``sigma`` are forwarded to ``predict_positions``.
    """
    misfit = predict_positions(
        params, gravity, epoch, jdsatepoch, jdsatepochF, times
    ) - r_obs
    sum_of_squares = jnp.sum(misfit**2)
    return 0.5 * sum_of_squares / sigma**2
118
119
120# %%
121# 4. Initial Guess & Optimization
122# --------------------------------
123#
124# We start from a slightly perturbed version of the true parameters and
125# use JAX's built-in BFGS optimizer (``jax.scipy.optimize.minimize``).
126#
127# Parameter scaling is critical: we normalize by the expected perturbation
128# size so that the BFGS initial Hessian approximation (identity) produces
129# conservative step sizes across all parameters.
130
# Names of the fitted parameters, in the order used by every parameter vector.
param_names = ["inclo", "nodeo", "ecco", "argpo", "mo", "no_kozai", "bstar"]

# Ground-truth vector, read off the parsed TLE in ``param_names`` order.
true_params = jnp.array([getattr(sat_true, name) for name in param_names])

# Quantities that are held fixed during the fit.
gravity = WGS72
jdsatepoch = float(sat_true.jdsatepoch)
jdsatepochF = float(sat_true.jdsatepochF)
# NOTE(review): 2433281.5 looks like the Julian date of the SGP4 epoch
# origin (1949-12-31 00:00 UT) -- confirm against ``sgp4init``.
epoch = jdsatepoch + jdsatepochF - 2433281.5

# Per-parameter perturbation size; the starting point is the truth plus a
# standard-normal draw multiplied element-wise by these magnitudes.
perturbation = jnp.array([1e-4, 1e-4, 1e-5, 1e-4, 1e-4, 1e-6, 1e-6])
key2 = jax.random.PRNGKey(123)
x0 = true_params + perturbation * jax.random.normal(key2, shape=(7,))

# Parameter scaling: the optimizer works on ``params / param_scale`` so
# that a unit step in the scaled space is a perturbation-sized step in
# physical space.  This keeps the BFGS line search stable.
param_scale = perturbation
156
157
def scaled_loss(x_scaled, param_scale, gravity, epoch, jdsatepoch,
                jdsatepochF, times, r_obs, sigma):
    """Evaluate ``loss_fn`` for a parameter vector given in scaled units.

    ``x_scaled`` is multiplied element-wise by ``param_scale`` to recover
    the physical parameters; all remaining arguments are forwarded to
    ``loss_fn`` unchanged.
    """
    physical_params = x_scaled * param_scale
    return loss_fn(
        physical_params, gravity, epoch, jdsatepoch, jdsatepochF,
        times, r_obs, sigma,
    )
164
165
# Optimize in the normalized space: start from the scaled initial guess.
x0_scaled = x0 / param_scale

result = jax_minimize(
    scaled_loss, x0_scaled,
    args=(param_scale, gravity, epoch, jdsatepoch, jdsatepochF,
          times, r_obs, sigma),
    method="BFGS",
)

# Map the solution back to physical units.
params_fit = result.x * param_scale
print(f"\nOptimization finished: fun_val = {float(result.fun):.4f}, "
      f"nit = {int(result.nit)}")

# The optimizer returns a result even when it stops early (e.g. a failed
# line search or the iteration cap).  Surface that instead of silently
# presenting a non-converged point as the best fit.
if not bool(result.success):
    print(f"WARNING: BFGS did not report convergence "
          f"(status = {int(result.status)}); treat the fit with caution.")

print()
print("Parameter        Initial          Optimized        True")
print("-" * 70)
for name, xi, fi, ti in zip(param_names, x0, params_fit, true_params):
    print(f"{name:12s}  {float(xi):14.8f}  {float(fi):14.8f}  {float(ti):14.8f}")
183
184
185# %%
186# 5. Fisher Information & Parameter Uncertainties
187# -------------------------------------------------
188#
189# With the Jacobian of the forward model we estimate parameter
190# uncertainties via the Fisher information matrix:
191#
192# .. math::
193#
194#    \mathbf{F} = \frac{1}{\sigma^2}\,\mathbf{J}^\top\mathbf{J},
195#    \qquad
196#    \mathrm{Cov}(\boldsymbol{\theta}) \approx \mathbf{F}^{-1}
197#
198# The 1-σ uncertainties are
199# :math:`\sqrt{\mathrm{diag}(\mathbf{F}^{-1})}`.
200
# Jacobian of the forward model with respect to its first argument (the
# parameter vector), evaluated at the fitted parameters.
jacobian_fn = jax.jit(jax.jacobian(predict_positions))
J_full = jacobian_fn(params_fit, gravity, epoch, jdsatepoch, jdsatepochF, times)

# Flatten the (time, xyz) axes into a single observation axis so that J is
# (n_observations, n_parameters).
n_times = times.shape[0]
n_params = params_fit.shape[0]
J = J_full.reshape(n_times * 3, n_params)

# Fisher information in physical units (kept for inspection).
F = J.T @ J / sigma**2

# The parameters differ by orders of magnitude, which makes F badly
# conditioned.  Invert in the same normalized space used for the
# optimization instead: with theta = param_scale * x,
#   J_x = J * param_scale            (column scaling)
#   Cov(theta) = diag(s) @ inv(F_x) @ diag(s)
# which is algebraically identical to inv(F) but numerically better behaved.
J_scaled = J * param_scale
F_scaled = J_scaled.T @ J_scaled / sigma**2
cov = jnp.linalg.inv(F_scaled) * jnp.outer(param_scale, param_scale)
uncertainties = jnp.sqrt(jnp.diag(cov))

print("\nParameter     Best Fit          1-sigma Uncertainty")
print("-" * 55)
for name, fi, ui in zip(param_names, params_fit, uncertainties):
    print(f"{name:12s}  {float(fi):14.8f}  +/- {float(ui):.2e}")
215
216
217# %%
218# 6. Results
219# ----------
220#
221# Visualize the position residuals and the parameter correlation matrix.
222
# Plotting is optional.  Only the import itself is guarded, so that an
# ImportError raised later (inside the plotting code) is not misreported
# as "matplotlib not available".
try:
    import matplotlib.pyplot as plt
except ImportError:
    print("(matplotlib not available — skipping plots)")
else:
    r_fit = predict_positions(
        params_fit, gravity, epoch, jdsatepoch, jdsatepochF, times)
    residuals = r_fit - r_obs

    fig, axes = plt.subplots(2, 1, figsize=(10, 8))

    # Plot 1: per-axis position residuals, with +/- sigma guide lines.
    ax = axes[0]
    for i, label in enumerate(["X", "Y", "Z"]):
        ax.plot(times, residuals[:, i], ".", label=label, markersize=4)
    ax.axhline(sigma, color="gray", ls="--", alpha=0.5, label=f"$\\pm${sigma} km")
    ax.axhline(-sigma, color="gray", ls="--", alpha=0.5)
    ax.set_xlabel("Time since epoch (min)")
    ax.set_ylabel("Residual (km)")
    ax.set_title("Position Residuals (fit - observed)")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 2: correlation matrix, i.e. the covariance normalized by the
    # outer product of the per-parameter standard deviations.
    ax = axes[1]
    std = jnp.sqrt(jnp.diag(cov))
    corr = cov / jnp.outer(std, std)
    im = ax.imshow(corr, cmap="RdBu_r", vmin=-1, vmax=1)
    tick_positions = range(len(param_names))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(param_names, rotation=45, ha="right")
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(param_names)
    ax.set_title("Parameter Correlation Matrix")
    fig.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()
261
262
263# %%
264# Summary
265# -------
266#
267# This example demonstrated a complete orbit determination workflow:
268#
269# 1. **Synthetic data** — propagated a known TLE and added 1 km noise
270# 2. **Forward model** — ``sgp4init`` + ``propagate`` map orbital
271#    parameters to positions
272# 3. **BFGS optimization** — JAX automatic differentiation provides
273#    exact gradients through the entire SGP4 computation
274# 4. **Uncertainty estimation** — the Fisher information matrix from
275#    the Jacobian gives 1-σ parameter uncertainties and correlations
276#
277# The key advantage of sgp4jax is **differentiability**: gradients,
278# Jacobians, and Hessians are computed automatically and efficiently,
279# enabling gradient-based optimization and rigorous uncertainty
280# quantification without finite differences.