Getting Started

Basic propagation

Parse a two-line element set (TLE) and propagate it to a time offset, given in minutes since the TLE epoch:

import jax.numpy as jnp
import sgp4jax

line1 = "1 25544U 98067A   20045.18587073  .00000950  00000-0  25302-4 0  9990"
line2 = "2 25544  51.6443 242.0161 0004397 264.6060 207.3845 15.49165514212791"

sat = sgp4jax.tle_to_satrec(line1, line2)
r, v, error = sgp4jax.propagate(sat, jnp.array(0.0))
print(f"Position (km): {r}")
print(f"Velocity (km/s): {v}")
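The time argument to propagate() is counted in minutes from the TLE epoch. Line 1 of the TLE stores that epoch as a two-digit year (columns 19-20) followed by a fractional day of the year (columns 21-32). The standard-library sketch below converts that field to a split Julian Date. The helper name tle_epoch_to_split_jd is hypothetical, and it assumes the usual convention that two-digit years 57-99 mean 19xx and 00-56 mean 20xx. Assuming sgp4jax follows python-sgp4's convention of splitting the epoch at midnight, the result should agree closely with sat.jdsatepoch and sat.jdsatepochF.

```python
import math
from datetime import datetime
from typing import Tuple

J2000_JD = 2451545.0               # Julian Date of 2000-01-01 12:00
J2000 = datetime(2000, 1, 1, 12)


def tle_epoch_to_split_jd(line1: str) -> Tuple[float, float]:
    """Return (jd, fr): jd is the Julian Date of midnight, fr the fraction of that day."""
    two_digit_year = int(line1[18:20])    # TLE columns 19-20
    day_of_year = float(line1[20:32])     # TLE columns 21-32; day 1.0 is Jan 1 00:00
    year = 1900 + two_digit_year if two_digit_year >= 57 else 2000 + two_digit_year

    # Julian Date of January 1, 00:00 of that year.
    jan1 = J2000_JD + (datetime(year, 1, 1) - J2000).total_seconds() / 86400.0

    whole_days = math.floor(day_of_year)
    return jan1 + (whole_days - 1), day_of_year - whole_days


line1 = "1 25544U 98067A   20045.18587073  .00000950  00000-0  25302-4 0  9990"
jd, fr = tle_epoch_to_split_jd(line1)
print(jd, fr)  # jd == 2458893.5 (2020-02-14 00:00), fr ≈ 0.18587073
```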

Julian Date propagation

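To propagate to an absolute time instead, use propagate_jd() with a split Julian Date: two parts, jd and fr, whose sum is the full date (here, the TLE epoch plus 0.5 day):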

import jax.numpy as jnp

jd = jnp.array(sat.jdsatepoch)
fr = jnp.array(sat.jdsatepochF + 0.5)  # 12 hours later
r, v, error = sgp4jax.propagate_jd(sat, jd, fr)
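Why split the date? A Julian Date near 2,458,893 needs 22 bits for its integer part alone, and JAX computes in 32-bit floats unless 64-bit mode is turned on with jax.config.update("jax_enable_x64", True); whether sgp4jax enables that for you is not stated here. The standard-library sketch below rounds values to float32 (via struct) to show the effect: a single float32 Julian Date is quantized to quarter days, while the split form keeps sub-second precision as long as the two parts are not summed back into one float32.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value and back."""
    return struct.unpack("f", struct.pack("f", x))[0]


MINUTES_PER_DAY = 1440.0
jd_full = 2458893.68587073         # ISS TLE epoch as one Julian Date
jd, fr = 2458893.5, 0.18587073     # the same instant, split at midnight

# Error, in minutes, after storing each representation as float32.
err_single = abs(to_float32(jd_full) - jd_full) * MINUTES_PER_DAY
err_split = abs((to_float32(jd) - jd) + (to_float32(fr) - fr)) * MINUTES_PER_DAY

print(f"single float32 JD error: {err_single:.1f} min")  # about 92.3 minutes
print(f"split float32 JD error:  {err_split:.2e} min")
```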

JIT compilation

propagate() is JIT-compiled by default, so wrapping it in jax.jit is optional; you can still do so explicitly:

import jax

jitted_propagate = jax.jit(sgp4jax.propagate)
r, v, error = jitted_propagate(sat, jnp.array(100.0))
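jax.jit compiles once per distinct input shape and dtype and reuses the compiled code afterwards, so calling with a new array shape (for example, a different number of time steps) pays the compilation cost again. The sketch below shows this with toy_propagate, a hypothetical stand-in that is not part of sgp4jax; the Python counter increments only while JAX traces the function.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def toy_propagate(t):
    """Hypothetical stand-in for a propagator; not part of sgp4jax."""
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return 7000.0 * jnp.stack([jnp.cos(t), jnp.sin(t)], axis=-1)


fast = jax.jit(toy_propagate)
fast(jnp.array(1.0))              # first call: trace and compile
fast(jnp.array(2.0))              # same shape and dtype: cached, no retrace
fast(jnp.linspace(0.0, 1.0, 5))   # new shape (5,): traced and compiled again
print(trace_count)                # 2
```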

Batch propagation with vmap

Propagate a single satellite over many time steps at once:

times = jnp.linspace(0, 1440, 1000)  # one day (1440 minutes), 1000 samples
batched = jax.vmap(sgp4jax.propagate, in_axes=(None, 0))
r_batch, v_batch, err_batch = batched(sat, times)
# r_batch.shape == (1000, 3)
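in_axes=(None, 0) tells jax.vmap to pass the first argument (the satellite record) unchanged to every call and to map over the leading axis of the second (the times), which is why the output gains a leading axis of length 1000. The sketch below repeats the pattern with a hypothetical NamedTuple standing in for SatRec (NamedTuples are JAX pytrees) and a circular-orbit toy in place of SGP4.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySat(NamedTuple):
    """Hypothetical stand-in for a SatRec: each field is an array leaf."""
    radius_km: jax.Array
    mean_motion: jax.Array  # radians per minute


def toy_propagate(sat: ToySat, t: jax.Array) -> jax.Array:
    """Circular-orbit position (km) at t minutes; a toy, not SGP4."""
    angle = sat.mean_motion * t
    return sat.radius_km * jnp.stack([jnp.cos(angle), jnp.sin(angle), jnp.zeros_like(angle)])


sat = ToySat(radius_km=jnp.array(6790.0), mean_motion=jnp.array(0.0676))
times = jnp.linspace(0, 1440, 1000)

# None: reuse the same satellite for every call; 0: map over the first axis of times.
r_batch = jax.vmap(toy_propagate, in_axes=(None, 0))(sat, times)
print(r_batch.shape)  # (1000, 3)
```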

Gradients

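Because propagation is written in JAX, you can differentiate any scalar function of position or velocity. For example, the gradient of the squared distance from Earth's center with respect to time (minutes since epoch):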

def loss(t):
    r, v, err = sgp4jax.propagate(sat, t)
    return jnp.sum(r ** 2)

grad_fn = jax.grad(loss)
g = grad_fn(jnp.array(100.0))
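For this particular loss the result has a closed form that makes a handy sanity check: d|r|²/dt = 2 r · dr/dt. The sketch below compares jax.grad with that formula on a hypothetical elliptical path (a toy, not SGP4). When applying the same check to sgp4jax, mind the units: t is in minutes but v is in km/s, so dr/dt in km/min is 60·v and the gradient should be close to 2 * 60 * jnp.dot(r, v), assuming the returned velocity is consistent with the time derivative of the returned position.

```python
import jax
import jax.numpy as jnp

A_KM, B_KM = 7000.0, 6900.0  # toy semi-axes (km)
N = 0.06                     # toy angular rate (radians per minute)


def toy_position(t):
    """Hypothetical elliptical path in km at t minutes; a toy, not SGP4."""
    angle = N * t
    return jnp.stack([A_KM * jnp.cos(angle), B_KM * jnp.sin(angle), 0.0 * angle])


def toy_velocity(t):
    """Analytic dr/dt of toy_position, in km per minute."""
    angle = N * t
    return jnp.stack([-A_KM * N * jnp.sin(angle), B_KM * N * jnp.cos(angle), 0.0 * angle])


def loss(t):
    return jnp.sum(toy_position(t) ** 2)  # squared distance from the origin


t = jnp.array(100.0)
g = jax.grad(loss)(t)
expected = 2.0 * jnp.dot(toy_position(t), toy_velocity(t))
print(g, expected)  # agree to float32 rounding
```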

GCRF output

SGP4 outputs position and velocity in the TEME (True Equator, Mean Equinox) frame. To get output in the GCRF (Geocentric Celestial Reference Frame, whose axes are aligned with the ICRS), use the GCRF convenience functions:

# Propagate directly to GCRF
r_gcrf, v_gcrf, error = sgp4jax.propagate_gcrf(sat, jnp.array(100.0))
print(f"Position (GCRF, km): {r_gcrf}")
print(f"Velocity (GCRF, km/s): {v_gcrf}")

# Or use Julian Date
jd = jnp.array(sat.jdsatepoch)
fr = jnp.array(sat.jdsatepochF + 0.5)
r_gcrf, v_gcrf, error = sgp4jax.propagate_jd_gcrf(sat, jd, fr)

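You can also apply the transform separately with teme_to_gcrf(), passing the split Julian Date of the propagated state (here, 100 minutes after epoch, which is 100/1440 day):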

r_teme, v_teme, error = sgp4jax.propagate(sat, jnp.array(100.0))
jd = jnp.array(sat.jdsatepoch)
fr = jnp.array(sat.jdsatepochF + 100.0 / 1440.0)
r_gcrf, v_gcrf = sgp4jax.teme_to_gcrf(r_teme, v_teme, jd, fr)
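The fr line above converts a 100-minute offset into days (100 / 1440) and adds it to the epoch fraction. The standard-library helpers below, whose names are hypothetical, generalize that conversion in both directions and carry whole days from fr into jd so that fr stays in [0, 1); keeping fr normalized is a convention of this sketch, not a stated sgp4jax requirement.

```python
import math
from typing import Tuple

MINUTES_PER_DAY = 1440.0


def minutes_to_split_jd(jd_epoch: float, fr_epoch: float, minutes: float) -> Tuple[float, float]:
    """Split Julian Date (jd, fr) lying `minutes` after the epoch (jd_epoch, fr_epoch)."""
    fr = fr_epoch + minutes / MINUTES_PER_DAY
    whole = math.floor(fr)  # carry whole days from fr into jd
    return jd_epoch + whole, fr - whole


def split_jd_to_minutes(jd: float, fr: float, jd_epoch: float, fr_epoch: float) -> float:
    """Minutes from the epoch to (jd, fr); the parts are subtracted separately to keep precision."""
    return ((jd - jd_epoch) + (fr - fr_epoch)) * MINUTES_PER_DAY


jd_epoch, fr_epoch = 2458893.5, 0.18587073  # ISS TLE epoch, split at midnight
jd, fr = minutes_to_split_jd(jd_epoch, fr_epoch, 100.0)
print(jd, fr)                                            # same day, fr ≈ 0.25531517
print(split_jd_to_minutes(jd, fr, jd_epoch, fr_epoch))   # ≈ 100.0
```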

Gravity models

Three gravity models are available:

  • sgp4jax.WGS72 (default)

  • sgp4jax.WGS84

  • sgp4jax.WGS72OLD

To use another model, pass it to tle_to_satrec() through the gravity argument:

sat = sgp4jax.tle_to_satrec(line1, line2, gravity=sgp4jax.WGS84)

Batch TLE parsing

Parse multiple TLEs at once with tles_to_satrec(), which returns a batched SatRec ready for jax.vmap:

import jax
import jax.numpy as jnp
import sgp4jax

tles = [
    ("1 25544U 98067A   20045.18587073  .00000950  00000-0  25302-4 0  9990",
     "2 25544  51.6443 242.0161 0004397 264.6060 207.3845 15.49165514212791"),
    ("1 00005U 58002B   20045.93498537  .00000023  00000-0  24901-3 0  9999",
     "2 00005  34.2513 243.5765 1847090 326.4186  22.2640 10.84386407185708"),
]

sats = sgp4jax.tles_to_satrec(tles)

# Propagate all satellites at once
batched = jax.vmap(sgp4jax.propagate, in_axes=(0, None))
r, v, err = batched(sats, jnp.array(0.0))
# r.shape == (2, 3)
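A batched SatRec is a single record whose array fields each carry a leading satellite axis, which is the axis in_axes=(0, None) maps over. The sketch below builds such a record by stacking per-satellite NamedTuples leaf by leaf with jax.tree_util.tree_map. ToySat and toy_propagate are hypothetical stand-ins, and this is not necessarily how tles_to_satrec() is implemented.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySat(NamedTuple):
    """Hypothetical stand-in for a SatRec."""
    radius_km: jax.Array
    mean_motion: jax.Array  # radians per minute


def toy_propagate(sat, t):
    """Circular-orbit position (km) at t minutes; a toy, not SGP4."""
    angle = sat.mean_motion * t
    return sat.radius_km * jnp.stack([jnp.cos(angle), jnp.sin(angle), jnp.zeros_like(angle)])


singles = [
    ToySat(jnp.array(6790.0), jnp.array(0.0676)),
    ToySat(jnp.array(8600.0), jnp.array(0.0473)),
]

# Stack matching leaves: every field becomes an array with a leading axis of length 2.
sats = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *singles)
print(sats.radius_km.shape)  # (2,)

# 0: map over the satellite axis of every field; None: share one time.
r = jax.vmap(toy_propagate, in_axes=(0, None))(sats, jnp.array(0.0))
print(r.shape)  # (2, 3)
```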

Convenience batch functions

gcrf_positions() propagates a single satellite to an array of Julian Dates and returns GCRF positions and velocities:

sat = sgp4jax.tle_to_satrec(tles[0][0], tles[0][1])
times_jd = jnp.linspace(2458900.5, 2458901.5, 100)
r_gcrf, v_gcrf = sgp4jax.gcrf_positions(sat, times_jd)
# r_gcrf.shape == (100, 3)

gcrf_positions_multi() propagates multiple satellites to an array of Julian Dates:

sats = sgp4jax.tles_to_satrec(tles)
r_gcrf, v_gcrf = sgp4jax.gcrf_positions_multi(sats, times_jd)
# r_gcrf.shape == (2, 100, 3)
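An output of shape (2, 100, 3) is what nesting two vmaps produces: the inner one maps over times, the outer one over satellites. The sketch below reproduces that layout with a hypothetical circular-orbit toy and a dict pytree as the batched record; it illustrates the shapes only and is not a claim about how gcrf_positions_multi() works internally.

```python
import jax
import jax.numpy as jnp


def toy_propagate(sat, t):
    """Circular-orbit position (km) at t minutes; a hypothetical toy, not SGP4."""
    angle = sat["mean_motion"] * t
    return sat["radius_km"] * jnp.stack([jnp.cos(angle), jnp.sin(angle), jnp.zeros_like(angle)])


# Batched record as a dict pytree: each leaf has a leading axis of 2 satellites.
sats = {"radius_km": jnp.array([6790.0, 8600.0]), "mean_motion": jnp.array([0.0676, 0.0473])}
times = jnp.linspace(0.0, 1440.0, 100)  # minutes from epoch

over_times = jax.vmap(toy_propagate, in_axes=(None, 0))  # one satellite, many times
over_both = jax.vmap(over_times, in_axes=(0, None))      # many satellites, shared times
r = over_both(sats, times)
print(r.shape)  # (2, 100, 3)
```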