# Source code for pyinterp.backends.xarray

# Copyright (c) 2026 CNES.
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
"""XArray backend.

Build interpolation objects from xarray.DataArray instances
"""

from __future__ import annotations

import dataclasses
from typing import TYPE_CHECKING, Any, cast

import numpy as np

from .. import cf, core
from ..regular_grid_interpolator import (
    InterpolationMethods,
    bivariate,
    quadrivariate,
    trivariate,
)


if TYPE_CHECKING:
    from collections.abc import Callable, Hashable, Iterable

    import xarray as xr

    from ..type_hints import (
        NDArray1D,
        NDArray1DDateTime64,
        NDArray1DFloat64,
        NDArray1DNumeric,
        NDArray1DNumericWithTime,
    )

# Public API of the module. RegularGridInterpolator is a documented public
# class of this module and is therefore exported alongside the grid classes.
__all__ = ["Grid2D", "Grid3D", "Grid4D", "RegularGridInterpolator"]

#: Number of dimensions of a two dimensional grid.
TWO_DIMENSIONS = 2

#: Number of dimensions of a three dimensional grid.
THREE_DIMENSIONS = 3

#: Number of dimensions of a four dimensional grid.
FOUR_DIMENSIONS = 4

#: Index of the longitude axis in a 2D, 3D, or 4D grid.
LONGITUDE_AXIS_INDEX = 0

#: Index of the temporal axis in a 3D or 4D grid.
TEMPORAL_AXIS_INDEX = 2


class AxisIdentifier:
    """Locate the longitude and latitude coordinates of a data array.

    The lookup inspects the ``units`` attribute of every coordinate of the
    data array and compares it with the unit sets provided by the ``cf``
    module.

    Args:
        data_array: The data array whose coordinates are inspected.

    """

    def __init__(self, data_array: xr.DataArray) -> None:
        """Store the data array whose coordinates will be inspected."""
        self.data_array = data_array

    def _axis(self, units: cf.AxisUnit) -> str | None:
        """Find the first coordinate whose units belong to ``units``.

        Args:
            units: Collection of unit strings identifying the wanted axis.

        Returns:
            The name (converted to ``str``) of the first matching coordinate,
            or None when no coordinate matches.

        """
        # Lazy scan: coordinates without a ``units`` attribute are skipped and
        # the search stops at the first match.
        candidates = (
            str(label)
            for label, variable in self.data_array.coords.items()
            if hasattr(variable, "units") and variable.units in units
        )
        return next(candidates, None)

    def longitude(self) -> str | None:
        """Find the coordinate that carries longitude units.

        Returns:
            The name of the longitude coordinate, or None if there is none.

        """
        longitude_units = cf.AxisLongitudeUnit()
        return self._axis(longitude_units)

    def latitude(self) -> str | None:
        """Find the coordinate that carries latitude units.

        Returns:
            The name of the latitude coordinate, or None if there is none.

        """
        latitude_units = cf.AxisLatitudeUnit()
        return self._axis(latitude_units)


def _identify_temporal_axis(
    data_array: xr.DataArray,
    dims: Iterable[Hashable],
) -> Hashable | None:
    """Find the first dimension whose coordinate holds dates or durations.

    Args:
        data_array: Data array providing the coordinates.
        dims: Dimension names to examine, in priority order.

    Returns:
        The first dimension whose coordinate dtype is a ``datetime64`` or a
        ``timedelta64``, or None when there is no such dimension.

    """
    temporal_kinds = (np.datetime64, np.timedelta64)
    available = data_array.coords
    for candidate in dims:
        # Dimensions without an associated coordinate cannot be temporal.
        if candidate not in available:
            continue
        dtype = available[candidate].dtype
        # Only a single temporal axis is supported: stop at the first match,
        # any further temporal dimension is ignored.
        if any(np.issubdtype(dtype, kind) for kind in temporal_kinds):
            return candidate
    return None


@dataclasses.dataclass(frozen=True)
class _DimInfo:
    """Canonical dimension layout resolved for a data array.

    Stores the data array as provided, the dimension names in canonical
    order, and flags describing which special axes were identified.
    """

    #: Data array as provided by the caller (original dimension order).
    _data_array: xr.DataArray
    #: Tuple of dimension names in standard order.
    dims: tuple[Hashable, ...]
    #: True if longitude was identified (at index 0).
    has_longitude: bool = False
    #: True if temporal axis was identified (at index 2).
    has_temporal: bool = False
    #: Indicates whether the dimension names differ in order from those in the
    #: provided data array.
    should_be_transposed: bool = False

    def axis(self, index: int) -> core.Axis | core.TemporalAxis:
        """Build the axis object for the dimension at the specified index.

        Args:
            index: Position of the dimension in :attr:`dims`.

        Returns:
            An axis with a period of 360 for an identified longitude, a
            temporal axis for an identified temporal dimension, and a plain
            axis otherwise.

        """
        # Coordinates are looked up by dimension name, so the array as
        # provided yields the same values; going through ``data_array`` would
        # transpose the whole array on every call for nothing.
        values = self._data_array.coords[self.dims[index]].values
        if index == LONGITUDE_AXIS_INDEX and self.has_longitude:
            return core.Axis(values, period=360.0)
        if index == TEMPORAL_AXIS_INDEX and self.has_temporal:
            return core.TemporalAxis(values)
        return core.Axis(values)

    @property
    def data_array(self) -> xr.DataArray:
        """Get the data array laid out in canonical dimension order.

        Returns:
            The provided array, transposed to :attr:`dims` when its original
            dimension order differs from the canonical one.

        """
        if self.should_be_transposed:
            return self._data_array.transpose(*self.dims)
        return self._data_array

    @property
    def datetime64(self) -> Hashable:
        """Get the name of the temporal dimension.

        Returns:
            The name of the dimension holding the temporal axis.

        Raises:
            AttributeError: If no temporal axis was identified.

        """
        if not self.has_temporal:
            raise AttributeError("No temporal axis present")
        return self.dims[TEMPORAL_AXIS_INDEX]


def _get_canonical_dimensions(
    data_array: xr.DataArray,
    ndims: int = 2,
) -> _DimInfo:
    """Get the name of dimensions that define the grid axes in canonical order.

    Identifies longitude, latitude, and temporal axes and returns the
    dimensions ordered as (Longitude, Latitude, Time, ...Others) to
    standardize grid processing.

    Target positions:
    - Index 0: Longitude (if present)
    - Index 1: Latitude (if present)
    - Index 2: Temporal axis (if present and ndims >= 3)

    Slots that cannot be filled by an identified axis are filled with the
    remaining dimensions, in their original order.

    Args:
        data_array: Provided data array.
        ndims: Number of dimensions expected for the variable.

    Returns:
        A _DimInfo instance containing ordered dimension names and flags
        indicating presence of longitude and temporal axes.

    Raises:
        ValueError: If the number of dimensions doesn't match ndims, if a
            temporal dimension is found in a grid with fewer than three
            dimensions, or if the dimensions cannot be mapped onto the
            canonical order.

    """
    if data_array.ndim != ndims:
        raise ValueError(
            f"The number of dimensions of the variable is incorrect. "
            f"Expected {ndims}, found {data_array.ndim}."
        )

    current_dims = list(data_array.dims)

    # Identify lon/lat axes. The identifier scans every coordinate of the
    # array, including coordinates that are not dimensions (for instance a
    # scalar coordinate). Only names that are actual dimensions may take part
    # in the ordering; anything else is treated as "not identified".
    ident = AxisIdentifier(data_array)
    lon_dim = ident.longitude()
    if lon_dim not in current_dims:
        lon_dim = None
    lat_dim = ident.latitude()
    if lat_dim not in current_dims:
        lat_dim = None

    # Identify temporal axis (only one supported)
    time_dim = _identify_temporal_axis(data_array, current_dims)
    if time_dim is not None and ndims < THREE_DIMENSIONS:
        # The temporal axis only has a slot at index 2; without this check the
        # dimension would be silently dropped from the ordering.
        raise ValueError(
            f"The temporal dimension {time_dim!r} cannot be used in a "
            f"{ndims}D grid: a temporal axis is only supported as the third "
            "axis of a 3D or 4D grid."
        )

    has_longitude = lon_dim is not None
    has_temporal = False

    special_dims = {lon_dim, lat_dim, time_dim}
    remaining_dims = [d for d in current_dims if d not in special_dims]

    final_dims: list[Hashable] = []

    # Slot 0: Longitude
    if lon_dim is not None:
        final_dims.append(lon_dim)
    elif remaining_dims:
        final_dims.append(remaining_dims.pop(0))

    # Slot 1: Latitude
    if lat_dim is not None:
        final_dims.append(lat_dim)
    elif remaining_dims:
        final_dims.append(remaining_dims.pop(0))

    # Slot 2: Time
    if ndims >= THREE_DIMENSIONS:
        if time_dim is not None:
            final_dims.append(time_dim)
            has_temporal = True
        elif remaining_dims:
            final_dims.append(remaining_dims.pop(0))

    # Fill remaining slots with whatever is left
    final_dims.extend(remaining_dims)

    # Explicit check rather than ``assert`` so that it is still enforced when
    # Python runs with optimizations enabled.
    if len(final_dims) != ndims:
        raise ValueError(
            f"Unable to map the dimensions {tuple(current_dims)!r} onto a "
            f"{ndims}D grid (resolved {tuple(final_dims)!r})."
        )

    return _DimInfo(
        data_array,
        tuple(final_dims),
        has_longitude,
        has_temporal,
        tuple(current_dims) != tuple(final_dims),
    )


def _coords(
    coords: dict[Hashable, NDArray1D],
    dims: tuple[Hashable, ...],
    datetime64: tuple[Hashable, core.TemporalAxis] | None = None,
) -> tuple[NDArray1D | NDArray1DDateTime64, ...]:
    """Order the user coordinates as expected by the interpolation functions.

    Args:
        coords: Mapping from dimension names to the coordinates to
            interpolate.
        dims: Dimension names of the grid, in grid order.
        datetime64: Optional pair made of the name of the temporal dimension
            and the temporal axis used to convert its coordinates.

    Returns:
        The coordinates sorted in the order of ``dims``; the coordinates of
        the temporal dimension, if any, are converted by the temporal axis.

    Raises:
        TypeError: If ``coords`` is not a dictionary.
        IndexError: If the number of coordinates differs from the number of
            dimensions, or if a coordinate name is not a grid dimension.

    """
    if not isinstance(coords, dict):
        raise TypeError("coords must be an instance of dict")

    n_coords, n_dims = len(coords), len(dims)
    if n_coords != n_dims:
        raise IndexError(
            f"Number of coordinates ({n_coords}) doesn't match "
            f"number of dimensions ({n_dims})"
        )

    unexpected = set(coords) - set(dims)
    if unexpected:
        names = ", ".join(str(item) for item in unexpected)
        raise IndexError("axes not handled by this grid: " + names)

    # Without a temporal axis every coordinate is forwarded untouched.
    if datetime64 is None:
        return tuple(cast("NDArray1D", coords[name]) for name in dims)

    temporal_dim, temporal_axis = datetime64

    def _select(name: Hashable) -> NDArray1D | NDArray1DDateTime64:
        """Fetch one coordinate, converting it if it is the temporal one."""
        value = coords[name]
        if name != temporal_dim:
            return cast("NDArray1D", value)
        return temporal_axis.cast_to_temporal_axis(
            cast("NDArray1DDateTime64", value)
        )

    return tuple(_select(name) for name in dims)


class _GridHolder:
    """Base class for grid holders.

    Wraps a core grid together with the names of its dimensions and forwards
    unknown attribute lookups to the wrapped grid.

    Args:
        grid: The core grid to wrap.
        dims: Dimension names of the grid, in grid order.

    """

    def __init__(
        self,
        grid: core.GridHolder,
        dims: tuple[Hashable, ...],
    ) -> None:
        """Initialize the grid holder."""
        self._dims = dims
        self._instance = grid
        # Pair (temporal dimension name, temporal axis); None when the grid
        # has no temporal axis.
        self._datetime64: tuple[Hashable, core.TemporalAxis] | None = None
        if self._instance.has_temporal_axis:
            self._datetime64 = (
                dims[2],
                cast("core.TemporalAxis", self._instance.z),
            )

    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
        """Delegate attribute access to the underlying grid instance.

        Only called when the regular lookup fails.

        Raises:
            AttributeError: If the wrapped grid is not set yet, or if it does
                not provide the requested attribute.

        """
        # Read the wrapped grid straight from the instance dictionary: going
        # through ``self._instance`` would call ``__getattr__`` again and
        # recurse without end when the attribute is not set yet (object
        # created without ``__init__``, e.g. while copying or unpickling).
        try:
            instance = self.__dict__["_instance"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(instance, name)

    def __repr__(self) -> str:
        """Return the representation of the underlying grid instance."""
        return repr(self._instance)


class Grid2D(_GridHolder):
    """Build a Grid2D from Xarray data.

    Create a 2D grid interpolation object from the provided Xarray data
    array. The axes are laid out in the order returned by
    ``_get_canonical_dimensions``.

    Args:
        data_array: Provided data

    Raises:
        ValueError: if the number of dimensions is different from 2.

    """

    def __init__(self, data_array: xr.DataArray) -> None:
        """Initialize the 2D grid from an Xarray data array."""
        canonical_dimensions = _get_canonical_dimensions(
            data_array, ndims=TWO_DIMENSIONS
        )
        grid = core.Grid(
            cast("core.Axis", canonical_dimensions.axis(0)),
            cast("core.Axis", canonical_dimensions.axis(1)),
            canonical_dimensions.data_array.values,
        )
        super().__init__(grid, canonical_dimensions.dims)

    def bivariate(
        self,
        coords: dict[Hashable, NDArray1DNumeric],
        method: InterpolationMethods = "bilinear",
        **kwargs: Any,  # noqa: ANN401
    ) -> np.ndarray:
        """Evaluate the interpolation defined for the given coordinates.

        Args:
            coords: Mapping from dimension names to the coordinates to
                interpolate. Coordinates must be array-like.
            method: Interpolation method. See
                :py:func:`pyinterp.bivariate` for more details.
            **kwargs: Additional keyword arguments provided to the
                interpolation method.

        Returns:
            The interpolated values.

        Raises:
            IndexError: If coordinate dimensions don't match grid dimensions

        """
        # Coordinates are reordered to follow the axes of the grid.
        x, y = _coords(coords, self._dims, self._datetime64)
        return bivariate(
            self._instance,
            cast("NDArray1DFloat64", x),
            cast("NDArray1DFloat64", y),
            method=method,
            **kwargs,
        )
class Grid3D(_GridHolder):
    """Build a Grid3D from Xarray data.

    Create a 3D grid interpolation object from the provided Xarray data
    array. Supports temporal axes via datetime64 coordinates.

    Args:
        data_array: Provided 3D data array

    Raises:
        ValueError: if the number of dimensions is different from 3.

    """

    def __init__(self, data_array: xr.DataArray) -> None:
        """Initialize the 3D grid from an Xarray data array."""
        canonical_dimensions = _get_canonical_dimensions(
            data_array, ndims=THREE_DIMENSIONS
        )
        grid = core.Grid(
            cast("core.Axis", canonical_dimensions.axis(0)),
            cast("core.Axis", canonical_dimensions.axis(1)),
            # The third axis is the only one that may be temporal.
            cast(
                "core.Axis | core.TemporalAxis",
                canonical_dimensions.axis(2),
            ),
            canonical_dimensions.data_array.values,
        )
        super().__init__(grid, canonical_dimensions.dims)

    def trivariate(
        self,
        coords: dict[Hashable, NDArray1DNumericWithTime],
        method: InterpolationMethods = "bilinear",
        **kwargs: Any,  # noqa: ANN401
    ) -> np.ndarray:
        """Evaluate the interpolation defined for the given coordinates.

        Args:
            coords: Mapping from dimension names to the coordinates to
                interpolate. Coordinates must be array-like. If the third
                axis is temporal, provide datetime64 array.
            method: Interpolation method. See
                :py:func:`pyinterp.trivariate` for more details.
            **kwargs: Additional keyword arguments provided to the
                interpolation method.

        Returns:
            The interpolated values.

        Raises:
            IndexError: If coordinate dimensions don't match grid dimensions

        """
        # Coordinates are reordered to follow the axes of the grid; temporal
        # coordinates are converted by the temporal axis of the grid.
        x, y, z = _coords(coords, self._dims, self._datetime64)
        return trivariate(
            self._instance,
            cast("NDArray1DFloat64", x),
            cast("NDArray1DFloat64", y),
            cast("NDArray1DFloat64 | NDArray1DDateTime64", z),
            method=method,
            **kwargs,
        )
class Grid4D(_GridHolder):
    """Build a Grid4D from Xarray data.

    Create a 4D grid interpolation object from the provided Xarray data
    array. Supports temporal axes via datetime64 coordinates.

    Args:
        data_array: Provided 4D data array

    Raises:
        ValueError: if the number of dimensions is different from 4.

    """

    def __init__(self, data_array: xr.DataArray) -> None:
        """Initialize the 4D grid from an Xarray data array."""
        canonical_dimensions = _get_canonical_dimensions(
            data_array, ndims=FOUR_DIMENSIONS
        )
        grid = core.Grid(
            cast("core.Axis", canonical_dimensions.axis(0)),
            cast("core.Axis", canonical_dimensions.axis(1)),
            # The third axis is the only one that may be temporal.
            cast(
                "core.Axis | core.TemporalAxis",
                canonical_dimensions.axis(2),
            ),
            cast("core.Axis", canonical_dimensions.axis(3)),
            canonical_dimensions.data_array.values,
        )
        super().__init__(grid, canonical_dimensions.dims)

    def quadrivariate(
        self,
        coords: dict[Hashable, NDArray1DNumericWithTime],
        method: InterpolationMethods = "bilinear",
        **kwargs: Any,  # noqa: ANN401
    ) -> np.ndarray:
        """Evaluate the interpolation defined for the given coordinates.

        Args:
            coords: Mapping from dimension names to the coordinates to
                interpolate. Coordinates must be array-like. If the third
                axis is temporal, provide datetime64 array.
            method: Interpolation method. See
                :py:func:`pyinterp.quadrivariate` for more details.
            **kwargs: Additional keyword arguments provided to the
                interpolation method.

        Returns:
            The interpolated values.

        Raises:
            IndexError: If coordinate dimensions don't match grid dimensions

        """
        # Coordinates are reordered to follow the axes of the grid; temporal
        # coordinates are converted by the temporal axis of the grid.
        x, y, z, u = _coords(coords, self._dims, self._datetime64)
        return quadrivariate(
            self._instance,
            cast("NDArray1DFloat64", x),
            cast("NDArray1DFloat64", y),
            cast("NDArray1DFloat64 | NDArray1DDateTime64", z),
            cast("NDArray1DFloat64", u),
            method=method,
            **kwargs,
        )
class RegularGridInterpolator:
    """Interpolate on a regular grid in arbitrary dimensions.

    Perform interpolation on a regular grid with uneven spacing support.
    Automatically detects geodetic coordinates (lon/lat) using CF conventions
    and temporal axes (datetime64).

    The data must be defined on a regular grid; the grid spacing however may
    be uneven. Linear, nearest neighbors, inverse distance weighting, and
    bicubic interpolation are supported.

    Args:
        array: The xarray DataArray defining the regular grid in ``n``
            dimensions. Must be 2D, 3D, or 4D.

    Raises:
        NotImplementedError: if the number of dimensions in the array is
            less than 2 or more than 4.

    Notes:
        **Automatic Detection:**

        The interpolator automatically detects:

        - **Geodetic coordinates**: If lon/lat are found via CF conventions
          (units attribute).
        - **Temporal axes**: If a coordinate has dtype='datetime64', it will
          be treated as a temporal axis with proper interpolation
        - **Dimension count**: Automatically selects Grid2D, Grid3D, or
          Grid4D

        **Geodetic Detection (CF Conventions):**

        Longitude axes are detected if the coordinate has units attribute
        matching: ``degrees_east``, ``degree_east``, ``degree_E``,
        ``degrees_E``, ``degreeE``, or ``degreesE``

        Latitude axes are detected if the coordinate has units attribute
        matching: ``degrees_north``, ``degree_north``, ``degree_N``,
        ``degrees_N``, ``degreeN``, or ``degreesN``

        **Temporal Detection:**

        Any coordinate with dtype containing 'datetime64' is automatically
        treated as a temporal axis.

    Examples:
        2D sea surface temperature

        >>> sst = xr.open_dataarray("sst.nc")  # (lon, lat)
        >>> interp = RegularGridInterpolator(sst)
        >>> result = interp(
        ...     {"lon": [10.5, 20.3], "lat": [45.2, -30.1]}, method="bilinear"
        ... )

        3D ocean temperature with depth

        >>> temp = xr.open_dataarray("temp.nc")  # (lon, lat, depth)
        >>> interp = RegularGridInterpolator(temp)
        >>> result = interp(
        ...     {"lon": [10.5], "lat": [45.2], "depth": [25.0]},
        ...     method="bilinear",
        ... )

        3D SST time series (automatic temporal handling)

        >>> sst_time = xr.open_dataarray("sst_time.nc")  # (lon, lat, time)
        >>> interp = RegularGridInterpolator(sst_time)
        >>> result = interp(
        ...     {
        ...         "lon": [10.5],
        ...         "lat": [45.2],
        ...         "time": np.array(["2020-01-01"], dtype="datetime64"),
        ...     },
        ...     method="bilinear",
        ... )

    """

    def __init__(self, array: xr.DataArray) -> None:
        """Initialize the interpolator from an Xarray data array.

        Args:
            array: The xarray DataArray to interpolate. Must be 2D, 3D, or
                4D.

        Raises:
            NotImplementedError: If array is not 2D, 3D, or 4D.

        """
        ndim = len(array.shape)
        self._grid: Grid2D | Grid3D | Grid4D
        self._interp: Callable[..., Any]

        # Select the grid class and the matching interpolation method from
        # the number of dimensions of the array.
        if ndim == TWO_DIMENSIONS:
            self._grid = Grid2D(array)
            self._interp = self._grid.bivariate
        elif ndim == THREE_DIMENSIONS:
            self._grid = Grid3D(array)
            self._interp = self._grid.trivariate
        elif ndim == FOUR_DIMENSIONS:
            self._grid = Grid4D(array)
            self._interp = self._grid.quadrivariate
        else:
            raise NotImplementedError(
                f"Only 2D, 3D, and 4D grids can be interpolated. "
                f"Got {ndim}D grid."
            )

    @property
    def ndim(self) -> int:
        """Get the number of array dimensions.

        Returns:
            Number of array dimensions (2, 3, or 4).

        """
        return len(self._grid._dims)

    @property
    def grid(self) -> Grid2D | Grid3D | Grid4D:
        """Get the instance handling the regular grid for interpolations.

        Returns:
            The underlying Grid2D, Grid3D, or Grid4D instance.

        """
        return self._grid

    def __call__(
        self,
        coords: dict,
        method: InterpolationMethods = "bilinear",
        **kwargs: Any,  # noqa: ANN401
    ) -> np.ndarray:
        """Interpolate at coordinates.

        Perform interpolation at the specified coordinates using the chosen
        method and parameters.

        Args:
            coords: Mapping from dimension names to the new coordinates.
                Coordinates can be scalars or array-like. For temporal axes,
                provide datetime64 arrays.
            method: The method of interpolation to perform. Supported methods
                depend on the grid type. Common methods include:

                - Geometric methods: ``nearest``, ``bilinear``, ``idw``
                - Windowed methods: ``akima``, ``akima_periodic``,
                  ``bicubic``, ``bilinear``, ``c_spline``,
                  ``c_spline_not_a_knot``, ``c_spline_periodic``,
                  ``linear``, ``polynomial``, ``steffen``.
            **kwargs: Additional keyword arguments passed to the
                interpolation function. Common options include:

                - ``bounds_error`` (bool): Raise error if coordinates are out
                  of bounds. Default: False (returns NaN).
                - ``num_threads`` (int): Number of threads for parallel
                  computation. 0 uses all CPUs. Default: 0.

                For windowed methods (bicubic, c_spline, etc.), additional
                options include:

                - ``half_window_size_x`` (int): Half window size in X
                  direction
                - ``half_window_size_y`` (int): Half window size in Y
                  direction
                - ``boundary_mode`` (str): Boundary handling mode
                  (``"shrink"``, ``"undef"``)

                For 3D/4D grids:

                - ``third_axis`` (str): Method for 3rd axis (``"linear"``,
                  ``"nearest"``)
                - ``fourth_axis`` (str): Method for 4th axis (``"linear"``,
                  ``"nearest"``)

        Returns:
            Interpolated values as numpy array with same shape as input
            coordinate arrays.

        Raises:
            ValueError: If bounds_error=True and coordinates are out of
                bounds.
            IndexError: If coordinate dimensions don't match grid dimensions.

        Examples:
            Simple bilinear interpolation

            >>> result = interp(
            ...     {"lon": [10.5], "lat": [45.2]}, method="bilinear"
            ... )

            Bicubic with custom window size

            >>> result = interp(
            ...     {"lon": [10.5], "lat": [45.2]},
            ...     method="bicubic",
            ...     half_window_size_x=10,
            ...     half_window_size_y=10,
            ... )

            With bounds checking

            >>> result = interp(
            ...     {"lon": [10.5], "lat": [45.2]},
            ...     method="bilinear",
            ...     bounds_error=True,
            ... )

            Multi-threaded

            >>> result = interp(
            ...     {"lon": lon_array, "lat": lat_array},
            ...     method="bilinear",
            ...     num_threads=4,
            ... )

            3D with temporal axis

            >>> result = interp(
            ...     {
            ...         "lon": [10.5],
            ...         "lat": [45.2],
            ...         "time": np.array(["2020-01-01"], dtype="datetime64"),
            ...     },
            ...     method="bilinear",
            ... )

        """
        return self._interp(coords, method=method, **kwargs)