# Copyright (c) 2025 CNES
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
"""
Orbit interpolation.
====================
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import dataclasses
import numpy
if TYPE_CHECKING:
    from collections.abc import Iterator
    from .typing import NDArray, NDArrayTimeDelta
from . import core, geodetic
def interpolate(
    lon: NDArray,
    lat: NDArray,
    xp: NDArray,
    xi: NDArray,
    height: float = 0.0,
    wgs: geodetic.Coordinates | None = None,
    half_window_size: int = 10,
) -> tuple[NDArray, NDArray]:
    """Interpolate the given orbit at the given coordinates.
    Args:
        lon: Longitudes (in degrees).
        lat: Latitudes (in degrees).
        xp: The x-coordinates at which the orbit is defined.
        height: Height of the satellite above the Earth's surface (in meters).
        xi: The x-coordinates at which to evaluate the interpolated values.
        wgs: The World Geodetic System used to convert the coordinates.
        half_window_size: Half size of the window used to interpolate the
            orbit.
    Returns:
        Tuple[NDArray, NDArray]: The interpolated longitudes and latitudes.
    """
    wgs = wgs or geodetic.Coordinates()
    mz = wgs.spheroid.semi_major_axis / wgs.spheroid.semi_minor_axis()
    x, y, z = wgs.lla_to_ecef(
        lon,  # type: ignore[arg-type]
        lat,  # type: ignore[arg-type]
        numpy.full_like(lon, height),  # type: ignore[arg-type]
    )
    r = numpy.sqrt(x * x + y * y + z * z * mz * mz)
    x_axis = core.Axis((xp - xp[0]).astype(numpy.float64), 1e-6, False)
    xi = (xi - xp[0]).astype(numpy.float64)
    x = core.interpolate1d(
        x_axis,
        x,
        xi,  # type: ignore[arg-type]
        half_window_size=half_window_size,
    )
    y = core.interpolate1d(
        x_axis,
        y,
        xi,  # type: ignore[arg-type]
        half_window_size=half_window_size,
    )
    z = core.interpolate1d(
        x_axis,
        z,
        xi,  # type: ignore[arg-type]
        half_window_size=half_window_size,
    )
    r = core.interpolate1d(
        x_axis,
        r,
        xi,  # type: ignore[arg-type]
        half_window_size=half_window_size,
    )
    r /= numpy.sqrt(x * x + y * y + z * z)
    x *= r
    y *= r
    z *= r * (1 / mz)
    lon, lat, _ = wgs.ecef_to_lla(x, y, z)
    return lon, lat
def _rearrange_orbit(
    cycle_duration: numpy.timedelta64,
    lon: NDArray,
    lat: NDArray,
    time: NDArrayTimeDelta,
) -> tuple[NDArray, NDArray, NDArrayTimeDelta]:
    """Rearrange orbit starting from pass 1.
    Detect the beginning of pass 1 in the ephemeris. By definition, it is
    the first passage at southernmost latitude.
    Args:
        cycle_duration: Cycle time in seconds.
        lon: Longitudes (in degrees).
        lat: Latitudes (in degrees).
        time: Time since the beginning of the orbit.
    Returns:
        The orbit rearranged starting from pass 1.
    """
    dy = numpy.roll(lat, 1) - lat
    indexes = numpy.where((dy < 0) & (numpy.roll(dy, 1) >= 0))[0]
    # If the orbit is already starting from pass 1, nothing to do
    if indexes[0] < int(indexes.mean()):
        return lon, lat, time
    # Shift coordinates, so that the first point of the orbit is the beginning
    # of pass 1
    shift = indexes[-1]
    lon = numpy.hstack([lon[shift:], lon[:shift]])
    lat = numpy.hstack([lat[shift:], lat[:shift]])
    time = numpy.hstack([time[shift:], time[:shift]])
    time = (time - time[0]) % cycle_duration
    if numpy.any(time < numpy.timedelta64(0, 's')):
        raise ValueError('Time is negative')
    return lon, lat, time
def _calculate_pass_time(lat: NDArray,
                         time: NDArrayTimeDelta) -> NDArrayTimeDelta:
    """Compute the initial time of each pass.
    Args:
        lat: Latitudes (in degrees)
        time: Date of the latitudes (in seconds).
    Returns:
        Start date of half-orbits.
    """
    dy = numpy.roll(lat, 1) - lat
    indexes = numpy.where(((dy < 0) & (numpy.roll(dy, 1) >= 0))
                          | ((dy > 0)
                             & (numpy.roll(dy, 1) <= 0)))[0]
    # The duration of the first pass is zero.
    indexes[0] = 0
    return time[indexes]
[docs]
@dataclasses.dataclass(frozen=True)
class Orbit:
    """Properties of the orbit.
    Args:
        height: Height of the satellite (in meters).
        latitude: Latitudes (in degrees).
        longitude: Longitudes (in degrees).
        pass_time: Start date of half-orbits.
        time: Time elapsed since the beginning of the orbit.
        x_al: Along track distance (in meters).
        wgs: World Geodetic System used.
    """
    #: Height of the satellite (in meters).
    height: float
    #: Latitudes (in degrees).
    latitude: NDArray
    #: Longitudes (in degrees).
    longitude: NDArray
    #: Start date of half-orbits.
    pass_time: NDArrayTimeDelta
    #: Time elapsed since the beginning of the orbit.
    time: NDArrayTimeDelta
    #: Along track distance (in meters).
    x_al: NDArray
    #: Spheroid model used.
    wgs: geodetic.Spheroid | None
[docs]
    def cycle_duration(self) -> numpy.timedelta64:
        """Get the cycle duration."""
        return self.time[-1] 
[docs]
    def passes_per_cycle(self) -> int:
        """Get the number of passes per cycle."""
        return len(self.pass_time) 
[docs]
    def orbit_duration(self) -> numpy.timedelta64:
        """Get the orbit duration."""
        duration = self.cycle_duration().astype(
            'timedelta64[us]') / numpy.timedelta64(
                int(self.passes_per_cycle() // 2), 'us')
        return numpy.timedelta64(int(duration), 'us') 
[docs]
    def curvilinear_distance(self) -> numpy.ndarray:
        """Get the curvilinear distance."""
        return geodetic.LineString(
            self.longitude,  # type: ignore[arg-type]
            self.latitude,  # type: ignore[arg-type]
        ).curvilinear_distance(strategy='thomas', wgs=self.wgs) 
[docs]
    def pass_duration(self, number: int) -> numpy.timedelta64:
        """Get the duration of a given pass.
        Args:
            number: track number (must be in [1, passes_per_cycle()])
        Returns:
            numpy.datetime64: track duration
        """
        passes_per_cycle = self.passes_per_cycle()
        if number < 1 or number > passes_per_cycle:
            raise ValueError(f'number must be in [1, {passes_per_cycle}]')
        if number == passes_per_cycle:
            return (self.time[-1] - self.pass_time[-1] + self.time[1] -
                    self.time[0])
        return self.pass_time[number] - self.pass_time[number - 1] 
[docs]
    def decode_absolute_pass_number(self, number: int) -> tuple[int, int]:
        """Calculate the cycle and pass number from a given absolute pass
        number.
        Args:
            number (int): absolute pass number
        Returns:
            tuple: cycle and pass number
        """
        number -= 1
        return (int(number / self.passes_per_cycle()) + 1,
                (number % self.passes_per_cycle()) + 1) 
[docs]
    def encode_absolute_pass_number(self, cycle_number: int,
                                    pass_number: int) -> int:
        """Calculate the absolute pass number for a given half-orbit.
        Args:
            cycle_number (int): Cycle number
            pass_number (int): Pass number
        Returns:
            int: Absolute pass number
        """
        passes_per_cycle = self.passes_per_cycle()
        if not 1 <= pass_number <= passes_per_cycle:
            raise ValueError(f'pass_number must be in [1, {passes_per_cycle}')
        return (cycle_number - 1) * self.passes_per_cycle() + pass_number 
[docs]
    def delta_t(self) -> numpy.timedelta64:
        """Returns the average time difference between two measurements.
        Returns:
            int: average time difference
        """
        return numpy.diff(self.time).mean() 
[docs]
    def iterate(
        self,
        first_date: numpy.datetime64 | None = None,
        last_date: numpy.datetime64 | None = None,
        absolute_pass_number: int = 1
    ) -> Iterator[tuple[int, int, numpy.datetime64]]:
        """Obtain all half-orbits within the defined time interval.
        Args:
            first_date: First date of the period to be considered.
                Defaults to the current date.
            last_date: Last date of the period to be considered.
                Defaults to the current date plus the orbit duration.
            absolute_pass_number (int, optional): Absolute number of the first
                pass to be returned.
        Returns:
            iterator: An iterator for all passes in the interval pointing to
            the cycle number, pass number and start date of the half-orbit.
        """
        date = first_date or numpy.datetime64('now')
        last_date = last_date or date + self.cycle_duration()
        while date <= last_date:
            cycle_number, pass_number = self.decode_absolute_pass_number(
                absolute_pass_number)
            yield cycle_number, pass_number, date
            # Shift the date of the duration of the generated pass
            date += self.pass_duration(pass_number)
            # Update of the number of the next pass to be generated
            absolute_pass_number += 1
        return StopIteration  # type: ignore[return-value] 
 
[docs]
@dataclasses.dataclass(frozen=True)
class EquatorCoordinates:
    """Coordinates of the satellite at the equator."""
    #: Longitude
    longitude: float
    #: Product dataset name
    time: numpy.datetime64
[docs]
    @classmethod
    def undefined(cls) -> EquatorCoordinates:
        """Create an undefined instance."""
        return cls(numpy.nan, numpy.datetime64('NaT')) 
 
[docs]
@dataclasses.dataclass(frozen=True)
class Pass:
    """Class representing a pass of an orbit."""
    #: Nadir longitude of the pass (degrees)
    lon_nadir: NDArray
    #: Nadir latitude of the pass (degrees)
    lat_nadir: NDArray
    #: Time of the pass
    time: NDArrayTimeDelta
    #: Along track distance of the pass (in meters)
    x_al: NDArray
    #: Coordinates of the satellite at the equator
    equator_coordinates: EquatorCoordinates
[docs]
    def __len__(self) -> int:
        """Get the number of points in the pass."""
        return len(self.time) 
 
[docs]
@dataclasses.dataclass(frozen=True)
class Swath(Pass):
    """Class representing a pass of an orbit."""
    #: Longitude of the swath (degrees)
    lon: NDArray
    #: Latitude of the swath (degrees)
    lat: NDArray
    #: Across track distance of the pass (m)
    x_ac: NDArray
[docs]
    def mask(self, requirement_bounds: tuple[float, float]) -> NDArray:
        """Obtain a mask to set NaN values outside the mission requirements.
        Args:
            requirement_bounds (tuple): Limits of SWOT swath requirements:
                absolute value of the minimum and maximum across track
                distance.
        Returns:
            Mask set true, if the swath is outside the requirements of the
            mission.
        """
        valid = numpy.full_like(self.x_ac, numpy.nan)
        valid[(numpy.abs(self.x_ac) >= requirement_bounds[0])
              & (numpy.abs(self.x_ac) <= requirement_bounds[1])] = 1
        along_track = numpy.full(self.lon_nadir.shape, 1, dtype=numpy.float64)
        return along_track[:, numpy.newaxis] * valid 
[docs]
    def insert_central_pixel(self) -> Swath:
        """Return a swath with a central pixel dividing the swath in two by the
        reference ground track."""
        def _insert(array: NDArray, central_pixel: int,
                    fill_value: NDArray) -> NDArray:
            """Insert a central pixel in a given array."""
            return numpy.c_[array[:, :central_pixel],
                            fill_value[:, numpy.newaxis],
                            array[:, central_pixel:]]
        num_pixels = self.lon.shape[1] + 1
        num_lines = self.lon.shape[0]
        central_pixel = num_pixels // 2
        return Swath(
            self.lon_nadir, self.lat_nadir, self.time, self.x_al,
            self.equator_coordinates,
            _insert(self.lon, central_pixel, self.lon_nadir),
            _insert(self.lat, central_pixel, self.lat_nadir),
            _insert(self.x_ac, central_pixel,
                    numpy.zeros(num_lines, dtype=self.x_ac.dtype))) 
 
def _equator_properties(lon_nadir: NDArray, lat_nadir: NDArray,
                        time: NDArrayTimeDelta) -> EquatorCoordinates:
    """Calculate the position of the satellite at the equator."""
    if lon_nadir.size < 2:
        return EquatorCoordinates.undefined()
    # Search the nearest point to the equator
    i1 = (numpy.abs(lat_nadir)).argmin()
    i0 = i1 - 1 if i1 > 0 else 1
    if lat_nadir[i0] * lat_nadir[i1] > 0:
        i0, i1 = (i1, i1 + 1) if i1 < lat_nadir.size - 1 else (i1 - 1, i1)
    lon1 = lon_nadir[i0:i1 + 1]
    lat1 = lat_nadir[i0:i1 + 1]
    # Calculate the position of the satellite at the equator
    intersection = geodetic.LineString(
        lon1,  # type: ignore[arg-type]
        lat1,  # type: ignore[arg-type]
    ).intersection(
        geodetic.LineString(
            numpy.array([lon1[0] - 0.5, lon1[1] + 0.5]),
            numpy.array(  # type: ignore[arg-type]
                [0, 0],
                dtype='float64',
            ),
        ))
    if len(intersection) == 0:
        return EquatorCoordinates.undefined()
    point = intersection[0]
    # Calculate the time of the point on the equator
    lon1 = numpy.insert(lon1, 1, point.lon)
    lat1 = numpy.insert(lat1, 1, 0)
    x_al = geodetic.LineString(
        lon1,  # type: ignore[arg-type]
        lat1,  # type: ignore[arg-type]
    ).curvilinear_distance(strategy='thomas')
    # Pop the along track distance at the equator
    x_eq = x_al[1]
    x_al = numpy.delete(  # type: ignore[assignment]
        x_al,
        1,
    )
    return EquatorCoordinates(
        point.lon,
        numpy.interp(  # type: ignore[arg-type]
            x_eq, x_al, time[i0:i1 + 1].astype('i8')).astype(time.dtype),
    )
[docs]
def calculate_orbit(
    height: float,
    lon_nadir: NDArray,
    lat_nadir: NDArray,
    time: NDArrayTimeDelta,
    cycle_duration: numpy.timedelta64 | None = None,
    along_track_resolution: float | None = None,
    spheroid: geodetic.Spheroid | None = None,
) -> Orbit:
    """Calculate the orbit at the given height.
    Args:
        height: Height of the orbit, in meters.
        lon_nadir: Nadir longitude of the orbit (degrees).
        lat_nadir: Nadir latitude of the orbit (degrees).
        time: Time elapsed since the start of the orbit.
        cycle_duration: Duration of the cycle.
        along_track_resolution: Resolution of the along-track interpolation in
            kilometers. Defaults to 2 kilometers.
        spheroid: Spheroid to use for the calculations. Defaults to WGS84.
    Returns:
        Orbit object.
    """
    wgs = geodetic.Coordinates(spheroid)
    # If the first point of the given orbit starts at the equator, we need to
    # skew this first pass.
    if -40 <= lat_nadir[0] <= 40:
        dy = numpy.roll(lat_nadir, 1) - lat_nadir
        indexes = numpy.where(((dy < 0) & (numpy.roll(dy, 1) >= 0))
                              | ((dy > 0)
                                 & (numpy.roll(dy, 1) <= 0)))[0]
        lat_nadir = lat_nadir[indexes[1:]]
        lon_nadir = lon_nadir[indexes[1:]]
        time = time[indexes[1:]]
    lon_nadir = geodetic.normalize_longitudes(lon_nadir)
    time = time.astype('m8[ns]')
    if numpy.mean(numpy.diff(time)) > numpy.timedelta64(500, 'ms'):
        time_hr = numpy.arange(time[0],
                               time[-1],
                               numpy.timedelta64(500, 'ms'),
                               dtype=time.dtype)
        lon_nadir, lat_nadir = interpolate(lon_nadir,
                                           lat_nadir,
                                           time.astype('i8'),
                                           time_hr.astype('i8'),
                                           height=height,
                                           wgs=wgs,
                                           half_window_size=50)
        time = time_hr
    if cycle_duration is not None:
        indexes = numpy.where(time < cycle_duration)[0]
        lon_nadir = lon_nadir[indexes]
        lat_nadir = lat_nadir[indexes]
        time = time[indexes]
        del indexes
    # Rearrange orbit starting from pass 1
    lon_nadir, lat_nadir, time = _rearrange_orbit(
        time[-1] + time[1] - time[0],
        lon_nadir,
        lat_nadir,
        time,
    )
    # Calculates the along track distance (km)
    distance = geodetic.LineString(
        lon_nadir,  # type: ignore[arg-type]
        lat_nadir,  # type: ignore[arg-type]
    ).curvilinear_distance(strategy='thomas', wgs=spheroid) * 1e-3
    # Interpolate the final orbit according the given along track resolution
    x_al = numpy.arange(distance[0],
                        distance[-2],
                        along_track_resolution or 2,
                        dtype=distance.dtype)
    lon_nadir, lat_nadir = interpolate(lon_nadir[:-1],
                                       lat_nadir[:-1],
                                       distance[:-1],
                                       x_al,
                                       height=height,
                                       wgs=wgs,
                                       half_window_size=10)
    time = numpy.interp(x_al, distance[:-1],
                        time[:-1].astype('i8')).astype(time.dtype)
    return Orbit(
        height,
        lat_nadir,
        lon_nadir,
        numpy.sort(_calculate_pass_time(lat_nadir, time)),
        time,
        x_al,
        wgs.spheroid,  # type: ignore[arg-type]
    ) 
[docs]
def calculate_pass(
    pass_number: int,
    orbit: Orbit,
    *,
    bbox: geodetic.Box | None = None,
) -> Pass | None:
    """Get the properties of a swath of an half-orbit.
    Args:
        pass_number: Pass number
        orbit: Orbit describing the pass to be calculated.
        bbox: Bounding box of the pass. Defaults to the whole Earth.
    Returns:
        The properties of the pass.
    """
    index = pass_number - 1
    # Selected indexes corresponding to the current pass
    if index == len(orbit.pass_time) - 1:
        indexes = numpy.where(orbit.time >= orbit.pass_time[-1])[0]
    else:
        indexes = numpy.where((orbit.time >= orbit.pass_time[index])
                              & (orbit.time < orbit.pass_time[index + 1]))[0]
    if len(indexes) < 5:
        return None
    lon_nadir = orbit.longitude[indexes]
    lat_nadir = orbit.latitude[indexes]
    time = orbit.time[indexes]
    x_al = orbit.x_al[indexes]
    # Selects the orbit in the defined box
    if bbox is not None:
        mask = bbox.covered_by(
            lon_nadir,  # type: ignore[arg-type]
            lat_nadir,  # type: ignore[arg-type]
        )
        if numpy.all(~mask):
            return None
        if numpy.any(mask):
            lon_nadir = lon_nadir[mask]
            lat_nadir = lat_nadir[mask]
            time = time[mask]
            x_al = x_al[mask]
    equator_coordinates = _equator_properties(lon_nadir, lat_nadir, time)
    return Pass(
        lon_nadir,
        lat_nadir,
        time,
        x_al,
        equator_coordinates,
    ) 
[docs]
def calculate_swath(
    half_orbit: Pass,
    *,
    across_track_resolution: float | None = None,
    along_track_resolution: float | None = None,
    half_swath: float | None = None,
    half_gap: float | None = None,
    spheroid: geodetic.Spheroid | None = None,
) -> Swath:
    """Get the properties of a swath of an half-orbit.
    Args:
        half_orbit: Half-orbit used to calculate the swath.
        bbox: Bounding box of the pass. Defaults to the whole Earth.
        across_track_resolution: Distance, in km, between two points across
            track direction. Defaults to 2 km.
        along_track_resolution: Distance, in km, between two points along track
            direction. Defaults to 2 km.
        half_swath: Distance, in km, between the nadir and the center of the
            last pixel of the swath. Defaults to 70 km.
        half_gap: Distance, in km, between the nadir and the center of the first
            pixel of the swath. Defaults to 2 km.
        spheroid: The spheroid to use for the calculation. Defaults to ``None``,
            which means the WGS-84 spheroid is used.
    Returns:
        The properties of the pass.
    """
    across_track_resolution = across_track_resolution or 2.0
    along_track_resolution = along_track_resolution or 2
    half_swath = half_swath or 70.0
    half_gap = half_gap or 2.0
    # Compute across track distances from nadir
    # Number of points in half of the swath
    half_swath = int((half_swath - half_gap) / across_track_resolution) + 1
    x_ac = numpy.arange(half_swath,
                        dtype=float) * along_track_resolution + half_gap
    x_ac = numpy.hstack((-numpy.flip(x_ac), x_ac)) * 1e3
    x_ac = numpy.full((len(half_orbit), x_ac.size), x_ac)
    lon, lat = core.geodetic.calculate_swath(
        half_orbit.lon_nadir,  # type: ignore[arg-type]
        half_orbit.lat_nadir,  # type: ignore[arg-type]
        across_track_resolution * 1e3,
        half_gap * 1e3,
        half_swath,
        spheroid,
    )
    return Swath(
        half_orbit.lon_nadir,
        half_orbit.lat_nadir,
        half_orbit.time,
        half_orbit.x_al,
        half_orbit.equator_coordinates,
        lon,
        lat,
        x_ac,
    )