# Copyright (c) 2026 CNES.
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
"""Orbit interpolation."""
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from collections.abc import Iterator
from .type_hints import (
NDArray1DFloat64,
NDArray1DTimeDelta64,
NDArray2DBool,
NDArray2DFloat64,
)
from . import core
from .core.geometry.geographic import Box, Coordinates, LineString, Spheroid
from .core.geometry.geographic.algorithms import (
Strategy,
curvilinear_distance,
for_each_point_covered_by,
intersection,
)
#: Minimum number of points required to process a pass.
_MIN_POINTS = 5
#: Minimum points to compute satellite equator position.
_MIN_EQUATOR_POINTS = 2
#: Latitude threshold to consider that the satellite is at the equator.
_EQUATOR_LAT_THRESHOLD = 40.0
def interpolate(
lon: NDArray1DFloat64,
lat: NDArray1DFloat64,
xp: NDArray1DFloat64,
xi: NDArray1DFloat64,
height: float = 0.0,
coordinates: Coordinates | None = None,
half_window_size: int = 3,
) -> tuple[NDArray1DFloat64, NDArray1DFloat64]:
"""Interpolate the given orbit at the given coordinates.
Args:
lon: Longitudes (in degrees).
lat: Latitudes (in degrees).
xp: The x-coordinates at which the orbit is defined.
height: Height of the satellite above the Earth's surface (in meters).
xi: The x-coordinates at which to evaluate the interpolated values.
coordinates: The geographic coordinates system used to convert the
coordinates from geodetic to ECEF and vice versa. If None, a
WGS-84 coordinate system is used.
half_window_size: Half size of the window used to interpolate the
orbit.
Returns:
The interpolated longitudes and latitudes.
"""
coordinates = coordinates or Coordinates()
spheroid = coordinates.spheroid
mz = spheroid.semi_major_axis / spheroid.semi_minor_axis()
x, y, z = coordinates.lla_to_ecef(
lon,
lat,
np.full_like(lon, height),
)
r = np.sqrt(x * x + y * y + z * z * mz * mz)
x_axis = core.Axis(xp - xp[0], 1e-6)
xi = xi - xp[0]
config = (
core.config.windowed.Univariate.c_spline()
.with_half_window_size(half_window_size)
.with_boundary_mode(
core.config.windowed.BoundaryConfig.shrink(),
)
)
x = core.univariate( # type: ignore[assignment]
core.Grid(x_axis, x),
xi, # type: ignore[arg-type]
config=config,
)
y = core.univariate( # type: ignore[assignment]
core.Grid(x_axis, y),
xi, # type: ignore[arg-type]
config=config,
)
z = core.univariate( # type: ignore[assignment]
core.Grid(x_axis, z),
xi, # type: ignore[arg-type]
config=config,
)
r = core.univariate(
core.Grid(x_axis, r),
xi, # type: ignore[arg-type]
config=config,
)
r /= np.sqrt(x * x + y * y + z * z)
x *= r
y *= r
z *= r * (1 / mz)
lon, lat, _ = coordinates.ecef_to_lla(
x, # type: ignore[arg-type]
y, # type: ignore[arg-type]
z, # type: ignore[arg-type]
)
return lon, lat
def _rearrange_orbit(
cycle_duration: np.timedelta64,
lon: NDArray1DFloat64,
lat: NDArray1DFloat64,
time: NDArray1DTimeDelta64,
) -> tuple[NDArray1DFloat64, NDArray1DFloat64, NDArray1DTimeDelta64]:
"""Rearrange orbit to start from pass 1.
Detect the beginning of pass 1 in the ephemeris and reorder the data
accordingly. By definition, pass 1 starts at the first passage at
southernmost latitude.
Args:
cycle_duration: Cycle time in seconds.
lon: Longitudes (in degrees).
lat: Latitudes (in degrees).
time: Time since the beginning of the orbit.
Returns:
The orbit rearranged starting from pass 1.
"""
dy = np.roll(lat, 1) - lat
indexes = np.where((dy < 0) & (np.roll(dy, 1) >= 0))[0]
# If the orbit is already starting from pass 1, nothing to do
if indexes[0] < int(indexes.mean()):
return lon, lat, time
# Shift coordinates, so that the first point of the orbit is the beginning
# of pass 1
shift = indexes[-1]
lon = np.hstack([lon[shift:], lon[:shift]])
lat = np.hstack([lat[shift:], lat[:shift]])
time = np.hstack([time[shift:], time[:shift]])
time = (time - time[0]) % cycle_duration
if np.any(time < np.timedelta64(0, "s")):
raise ValueError("Time is negative")
return lon, lat, time
def _calculate_pass_time(
lat: NDArray1DFloat64, time: NDArray1DTimeDelta64
) -> NDArray1DTimeDelta64:
"""Compute the initial time of each pass.
Args:
lat: Latitudes (in degrees)
time: Date of the latitudes (in seconds).
Returns:
Start date of half-orbits.
"""
dy = np.roll(lat, 1) - lat
indexes = np.where(
((dy < 0) & (np.roll(dy, 1) >= 0)) | ((dy > 0) & (np.roll(dy, 1) <= 0))
)[0]
# The duration of the first pass is zero.
indexes[0] = 0
return time[indexes]
[docs]
@dataclasses.dataclass(frozen=True)
class Orbit:
"""Represent properties of the orbit.
Store and manage orbital parameters including position, timing, and
geodetic information.
Args:
height: Height of the satellite (in meters).
latitude: Latitudes (in degrees).
longitude: Longitudes (in degrees).
pass_time: Start date of half-orbits.
time: Time elapsed since the beginning of the orbit.
x_al: Along track distance (in meters).
wgs: World Geodetic System used.
"""
#: Height of the satellite (in meters).
height: float
#: Latitudes (in degrees).
latitude: NDArray1DFloat64
#: Longitudes (in degrees).
longitude: NDArray1DFloat64
#: Start date of half-orbits.
pass_time: NDArray1DTimeDelta64
#: Time elapsed since the beginning of the orbit.
time: NDArray1DTimeDelta64
#: Along track distance (in meters).
x_al: NDArray1DFloat64
#: Spheroid model used.
wgs: Spheroid | None
[docs]
def cycle_duration(self) -> np.timedelta64:
"""Get the cycle duration."""
return self.time[-1]
[docs]
def passes_per_cycle(self) -> int:
"""Get the number of passes per cycle.
Returns:
The number of passes in one complete cycle.
"""
return len(self.pass_time)
[docs]
def orbit_duration(self) -> np.timedelta64:
"""Get the orbit duration.
Returns:
The duration of one complete orbit.
"""
duration = self.cycle_duration().astype(
"timedelta64[us]"
) / np.timedelta64(int(self.passes_per_cycle() // 2), "us")
return np.timedelta64(int(duration), "us")
[docs]
def curvilinear_distance(self) -> NDArray1DFloat64:
"""Get the curvilinear distance.
Returns:
The curvilinear distance along the orbit.
"""
ls = LineString(
self.longitude,
self.latitude,
)
return curvilinear_distance(
ls,
spheroid=self.wgs,
strategy=Strategy.THOMAS,
)
[docs]
def pass_duration(self, number: int) -> np.timedelta64:
"""Get the duration of a given pass.
Args:
number: track number (must be in [1, passes_per_cycle()])
Returns:
numpy.datetime64: track duration
"""
passes_per_cycle = self.passes_per_cycle()
if number < 1 or number > passes_per_cycle:
raise ValueError(f"number must be in [1, {passes_per_cycle}]")
if number == passes_per_cycle:
return (
self.time[-1]
- self.pass_time[-1]
+ self.time[1]
- self.time[0]
)
return self.pass_time[number] - self.pass_time[number - 1]
[docs]
def decode_absolute_pass_number(self, number: int) -> tuple[int, int]:
"""Calculate cycle and pass numbers from an absolute pass number.
Convert an absolute pass number into its corresponding cycle and pass
number components.
Args:
number: Absolute pass number.
Returns:
A tuple containing the cycle number and pass number.
"""
number -= 1
return (
int(number / self.passes_per_cycle()) + 1,
(number % self.passes_per_cycle()) + 1,
)
[docs]
def encode_absolute_pass_number(
self, cycle_number: int, pass_number: int
) -> int:
"""Calculate the absolute pass number for a given half-orbit.
Args:
cycle_number (int): Cycle number
pass_number (int): Pass number
Returns:
int: Absolute pass number
"""
passes_per_cycle = self.passes_per_cycle()
if not 1 <= pass_number <= passes_per_cycle:
raise ValueError(f"pass_number must be in [1, {passes_per_cycle}")
return (cycle_number - 1) * self.passes_per_cycle() + pass_number
[docs]
def delta_t(self) -> np.timedelta64:
"""Return the average time difference between two measurements.
Calculate the mean time interval between consecutive measurements.
Returns:
Average time difference between measurements.
"""
return np.diff(self.time).mean()
[docs]
def iterate(
self,
first_date: np.datetime64 | None = None,
last_date: np.datetime64 | None = None,
absolute_pass_number: int = 1,
) -> Iterator[tuple[int, int, np.datetime64]]:
"""Obtain all half-orbits within the defined time interval.
Args:
first_date: First date of the period to be considered.
Defaults to the current date.
last_date: Last date of the period to be considered.
Defaults to the current date plus the orbit duration.
absolute_pass_number (int, optional): Absolute number of the first
pass to be returned.
Returns:
iterator: An iterator for all passes in the interval pointing to
the cycle number, pass number and start date of the half-orbit.
"""
date = first_date or np.datetime64("now")
last_date = last_date or date + self.cycle_duration()
while date <= last_date:
cycle_number, pass_number = self.decode_absolute_pass_number(
absolute_pass_number
)
yield cycle_number, pass_number, date
# Shift the date of the duration of the generated pass
date += self.pass_duration(pass_number)
# Update of the number of the next pass to be generated
absolute_pass_number += 1
return StopIteration # type: ignore[return-value]
[docs]
@dataclasses.dataclass(frozen=True)
class EquatorCoordinates:
"""Represent coordinates of the satellite at the equator.
Store the longitude and time when the satellite crosses the equator.
"""
#: Longitude
longitude: float
#: Product dataset name
time: np.datetime64
[docs]
@classmethod
def undefined(cls) -> EquatorCoordinates:
"""Create an undefined instance."""
return cls(np.nan, np.datetime64("NaT"))
[docs]
@dataclasses.dataclass(frozen=True)
class Pass:
"""Represent a pass of an orbit.
Store the properties of a single orbital pass including nadir coordinates,
timing, and along-track distance.
"""
#: Nadir longitude of the pass (degrees)
lon_nadir: NDArray1DFloat64
#: Nadir latitude of the pass (degrees)
lat_nadir: NDArray1DFloat64
#: Time of the pass
time: NDArray1DTimeDelta64
#: Along track distance of the pass (in meters)
x_al: NDArray1DFloat64
#: Coordinates of the satellite at the equator
equator_coordinates: EquatorCoordinates
[docs]
def __len__(self) -> int:
"""Get the number of points in the pass."""
return len(self.time)
[docs]
@dataclasses.dataclass(frozen=True)
class Swath(Pass):
"""Represent a swath of an orbital pass.
Extend the Pass class with additional swath-specific properties including
cross-track coordinates and distances.
"""
#: Longitude of the swath (degrees)
lon: NDArray2DFloat64
#: Latitude of the swath (degrees)
lat: NDArray2DFloat64
#: Across track distance of the pass (m)
x_ac: NDArray2DFloat64
[docs]
def mask(self, requirement_bounds: tuple[float, float]) -> NDArray2DBool:
"""Obtain a mask to set NaN values outside the mission requirements.
Args:
requirement_bounds (tuple): Limits of SWOT swath requirements:
absolute value of the minimum and maximum across track
distance.
Returns:
Mask set true, if the swath is outside the requirements of the
mission.
"""
valid = np.full_like(self.x_ac, 0, dtype=np.bool_)
valid[
(np.abs(self.x_ac) >= requirement_bounds[0])
& (np.abs(self.x_ac) <= requirement_bounds[1])
] = 1
along_track = np.full(self.lon_nadir.shape, 1, dtype=np.bool_)
return along_track[:, np.newaxis] * valid
[docs]
def insert_central_pixel(self) -> Swath:
"""Insert a central pixel dividing the swath in two.
Return a new swath with a central pixel added at the reference ground
track, effectively dividing the swath into two halves.
Returns:
A new Swath instance with the central pixel inserted.
"""
def _insert(
array: NDArray2DFloat64,
central_pixel: int,
fill_value: NDArray1DFloat64,
) -> NDArray2DFloat64:
"""Insert a central pixel in a given array."""
return np.c_[
array[:, :central_pixel],
fill_value[:, np.newaxis],
array[:, central_pixel:],
]
num_pixels = self.lon.shape[1] + 1
num_lines = self.lon.shape[0]
central_pixel = num_pixels // 2
return Swath(
self.lon_nadir,
self.lat_nadir,
self.time,
self.x_al,
self.equator_coordinates,
_insert(self.lon, central_pixel, self.lon_nadir),
_insert(self.lat, central_pixel, self.lat_nadir),
_insert(
self.x_ac,
central_pixel,
np.zeros(num_lines, dtype=self.x_ac.dtype),
),
)
def _equator_properties(
lon_nadir: NDArray1DFloat64,
lat_nadir: NDArray1DFloat64,
time: NDArray1DTimeDelta64,
) -> EquatorCoordinates:
"""Calculate the position of the satellite at the equator.
Determine where and when the satellite crosses the equator.
Args:
lon_nadir: Nadir longitudes (in degrees).
lat_nadir: Nadir latitudes (in degrees).
time: Time since the beginning of the orbit.
Returns:
The equator coordinates of the satellite.
"""
if lon_nadir.size < _MIN_EQUATOR_POINTS:
return EquatorCoordinates.undefined()
# Search the nearest point to the equator
i1 = (np.abs(lat_nadir)).argmin()
i0 = i1 - 1 if i1 > 0 else 1
if lat_nadir[i0] * lat_nadir[i1] > 0:
i0, i1 = (i1, i1 + 1) if i1 < lat_nadir.size - 1 else (i1 - 1, i1)
lon1 = lon_nadir[i0 : i1 + 1]
lat1 = lat_nadir[i0 : i1 + 1]
# Calculate the position of the satellite at the equator
points = intersection(
LineString(
lon1,
lat1,
),
LineString(
np.array([lon1[0] - 0.5, lon1[1] + 0.5]),
np.array(
[0, 0],
dtype="float64",
),
),
)
if len(points) == 0:
return EquatorCoordinates.undefined()
point = points[0]
# Calculate the time of the point on the equator
lon1 = np.insert(lon1, 1, point.lon)
lat1 = np.insert(lat1, 1, 0)
x_al = curvilinear_distance(
LineString(
lon1,
lat1,
),
strategy=Strategy.THOMAS,
spheroid=None,
)
# Pop the along track distance at the equator
x_eq = x_al[1]
x_al = np.delete(
x_al,
1,
)
return EquatorCoordinates(
point.lon,
np.interp(x_eq, x_al, time[i0 : i1 + 1].astype("i8")).astype(
time.dtype
),
)
[docs]
def calculate_orbit(
height: float,
lon_nadir: NDArray1DFloat64,
lat_nadir: NDArray1DFloat64,
time: NDArray1DTimeDelta64,
cycle_duration: np.timedelta64 | None = None,
along_track_resolution: float | None = None,
spheroid: Spheroid | None = None,
) -> Orbit:
"""Calculate the orbit at the given height.
Args:
height: Height of the orbit, in meters.
lon_nadir: Nadir longitude of the orbit (degrees).
lat_nadir: Nadir latitude of the orbit (degrees).
time: Time elapsed since the start of the orbit.
cycle_duration: Duration of the cycle.
along_track_resolution: Resolution of the along-track interpolation in
kilometers. Defaults to 2 kilometers.
spheroid: Spheroid to use for the calculations. Defaults to WGS84.
Returns:
Orbit object.
"""
coordinates = Coordinates(spheroid)
# If the first point of the given orbit starts at the equator, we need to
# skew this first pass.
if -_EQUATOR_LAT_THRESHOLD <= lat_nadir[0] <= _EQUATOR_LAT_THRESHOLD:
dy = np.roll(lat_nadir, 1) - lat_nadir
indexes = np.where(
((dy < 0) & (np.roll(dy, 1) >= 0))
| ((dy > 0) & (np.roll(dy, 1) <= 0))
)[0]
lat_nadir = lat_nadir[indexes[1:]]
lon_nadir = lon_nadir[indexes[1:]]
time = time[indexes[1:]]
lon_nadir = (lon_nadir + 180) % 360 - 180
time = time.astype("m8[ms]")
if np.mean(np.diff(time)) > np.timedelta64(500, "ms"):
time_hr = np.arange(
time[0], time[-1], np.timedelta64(500, "ms"), dtype="m8[ms]"
)
lon_nadir, lat_nadir = interpolate(
lon_nadir,
lat_nadir,
time.view("i8"),
time_hr.view("i8"),
height=height,
coordinates=coordinates,
half_window_size=50,
)
time = time_hr
if cycle_duration is not None:
indexes = np.where(time < cycle_duration)[0]
lon_nadir = lon_nadir[indexes]
lat_nadir = lat_nadir[indexes]
time = time[indexes]
del indexes
# Rearrange orbit starting from pass 1
lon_nadir, lat_nadir, time = _rearrange_orbit(
time[-1] + time[1] - time[0],
lon_nadir,
lat_nadir,
time,
)
# Calculates the along track distance (km)
distance = (
curvilinear_distance(
LineString(
lon_nadir,
lat_nadir,
),
spheroid,
Strategy.THOMAS,
)
* 1e-3
)
# Interpolate the final orbit according the given along track resolution
x_al = np.arange(
distance[0],
distance[-2],
along_track_resolution or 2,
dtype=distance.dtype,
)
lon_nadir, lat_nadir = interpolate(
lon_nadir[:-1],
lat_nadir[:-1],
distance[:-1],
x_al,
height=height,
coordinates=coordinates,
half_window_size=10,
)
time = np.interp(x_al, distance[:-1], time[:-1].astype("i8")).astype(
time.dtype
)
return Orbit(
height,
lat_nadir,
lon_nadir,
np.sort(_calculate_pass_time(lat_nadir, time)),
time,
x_al,
coordinates.spheroid, # type: ignore[arg-type]
)
[docs]
def calculate_pass(
pass_number: int,
orbit: Orbit,
*,
bbox: Box | None = None,
) -> Pass | None:
"""Get the properties of a swath of an half-orbit.
Args:
pass_number: Pass number
orbit: Orbit describing the pass to be calculated.
bbox: Bounding box of the pass. Defaults to the whole Earth.
Returns:
The properties of the pass.
"""
index = pass_number - 1
# Selected indexes corresponding to the current pass
if index == len(orbit.pass_time) - 1:
indexes = np.where(orbit.time >= orbit.pass_time[-1])[0]
else:
indexes = np.where(
(orbit.time >= orbit.pass_time[index])
& (orbit.time < orbit.pass_time[index + 1])
)[0]
if len(indexes) < _MIN_POINTS:
return None
lon_nadir = orbit.longitude[indexes]
lat_nadir = orbit.latitude[indexes]
time = orbit.time[indexes]
x_al = orbit.x_al[indexes]
# Selects the orbit in the defined box
if bbox is not None:
mask = for_each_point_covered_by(
LineString(
lon_nadir,
lat_nadir,
),
bbox,
)
if np.all(~mask):
return None
if np.any(mask):
lon_nadir = lon_nadir[mask]
lat_nadir = lat_nadir[mask]
time = time[mask]
x_al = x_al[mask]
equator_coordinates = _equator_properties(lon_nadir, lat_nadir, time)
return Pass(
lon_nadir,
lat_nadir,
time,
x_al,
equator_coordinates,
)
[docs]
def calculate_swath(
half_orbit: Pass,
*,
across_track_resolution: float | None = None,
along_track_resolution: float | None = None,
half_swath: float | None = None,
half_gap: float | None = None,
spheroid: Spheroid | None = None,
) -> Swath:
"""Get the properties of a swath of an half-orbit.
Args:
half_orbit: Half-orbit used to calculate the swath.
bbox: Bounding box of the pass. Defaults to the whole Earth.
across_track_resolution: Distance, in km, between two points across
track direction. Defaults to 2 km.
along_track_resolution: Distance, in km, between two points along track
direction. Defaults to 2 km.
half_swath: Distance, in km, between the nadir and the center of the
last pixel of the swath. Defaults to 70 km.
half_gap: Distance, in km, between the nadir and the center of the
first pixel of the swath. Defaults to 2 km.
spheroid: The spheroid to use for the calculation. Defaults to
``None``, which means the WGS-84 spheroid is used.
Returns:
The properties of the pass.
"""
across_track_resolution = across_track_resolution or 2.0
along_track_resolution = along_track_resolution or 2
half_swath = half_swath or 70.0
half_gap = half_gap or 2.0
# Compute across track distances from nadir
# Number of points in half of the swath
half_swath = int((half_swath - half_gap) / across_track_resolution) + 1
x_ac = (
np.arange(half_swath, dtype=float) * along_track_resolution + half_gap
)
x_ac = np.hstack((-np.flip(x_ac), x_ac)) * 1e3
x_ac = np.full((len(half_orbit), x_ac.size), x_ac)
lon, lat = core.geometry.satellite.calculate_swath(
half_orbit.lon_nadir,
half_orbit.lat_nadir,
across_track_resolution * 1e3,
half_gap * 1e3,
half_swath,
spheroid,
)
return Swath(
half_orbit.lon_nadir,
half_orbit.lat_nadir,
half_orbit.time,
half_orbit.x_al,
half_orbit.equator_coordinates,
lon,
lat,
x_ac,
)