# Copyright (c) 2024 CNES
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
"""
XArray
------
Build interpolation objects from xarray.DataArray instances
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import pickle
from .. import cf, core, grid, interpolator
if TYPE_CHECKING:
from collections.abc import Hashable
import numpy as np
import xarray as xr
__all__ = ['Grid2D', 'Grid3D', 'Grid4D', 'RegularGridInterpolator']
class AxisIdentifier:
    """Find the longitude and latitude coordinates of a data array.

    A coordinate is recognised through the value of its ``units`` attribute.

    Args:
        data_array: The data array whose coordinates are inspected.
    """

    def __init__(self, data_array: xr.DataArray):
        self.data_array = data_array

    def _axis(self, units: cf.AxisUnit) -> str | None:
        """Look up the first coordinate whose unit belongs to ``units``.

        Coordinates that do not expose a ``units`` attribute are skipped.

        Args:
            units: Container of the unit values accepted for the axis.

        Returns:
            The name of the first matching coordinate, or ``None`` if no
            coordinate matches.
        """
        matching = (
            label for label, variable in self.data_array.coords.items()
            if hasattr(variable, 'units') and variable.units in units)
        return next(matching, None)  # type: ignore[return-value]

    def longitude(self) -> str | None:
        """Look up the coordinate that describes the longitudes.

        Returns:
            The name of the longitude coordinate, or ``None`` if there is
            none.
        """
        longitude_units = cf.AxisLongitudeUnit()
        return self._axis(longitude_units)

    def latitude(self) -> str | None:
        """Look up the coordinate that describes the latitudes.

        Returns:
            The name of the latitude coordinate, or ``None`` if there is
            none.
        """
        latitude_units = cf.AxisLatitudeUnit()
        return self._axis(latitude_units)
def _dims_from_data_array(data_array: xr.DataArray,
geodetic: bool,
ndims: int | None = 2) -> tuple[str, str]:
"""Gets the name of the dimensions that define the grid axes. the
longitudes and latitudes of the data array.
Args:
data_array: Provided data array
geodetic: True, if the axes of the grid represent longitudes and
latitudes otherwise Cartesian
ndims: Number of dimension expected for the variable
Returns:
Longitude and latitude names
Raises:
ValueError if the provided data array doesn't define a
longitude/latitude axis
ValueError if the number of dimensions is different from the number of
dimensions of the grid provided.
"""
size = len(data_array.shape)
if size != ndims:
raise ValueError(
'The number of dimensions of the variable is incorrect. Expected '
f'{ndims}, found {size}.')
if not geodetic:
return tuple( # type: ignore[return-value]
data_array.coords)[:2]
ident = AxisIdentifier(data_array)
lon = ident.longitude()
if lon is None:
raise ValueError("The dataset doesn't define a longitude axis")
lat = ident.latitude()
if lat is None:
raise ValueError("The dataset doesn't define a latitude axis")
return lon, lat
def _coords(
coords: dict,
dims: tuple,
datetime64: tuple[Hashable, core.TemporalAxis] | None = None,
) -> tuple:
"""Get the list of arguments to provide to the grid interpolation
functions.
Args:
coords: Mapping from dimension names to the new coordinates. New
coordinate can be a scalar, array-like.
dims: List of dimensions handled by the grid
datetime64: Properties of the axis used
Returns:
The tuple of arguments decoded.
Raises:
TypeError if coords are not one instance of ``dict``
IndexError if the number of coordinates is different from the
number of grid dimensions
IndexError if one of the coordinates is not used by this grid.
"""
if not isinstance(coords, dict):
raise TypeError('coords must be an instance of dict')
if len(coords) != len(dims):
raise IndexError('too many indices for array')
unknown = set(coords) - set(dims)
if unknown:
raise IndexError('axes not handled by this grid: ' +
', '.join([str(item) for item in unknown]))
# Is it necessary to manage a time axis?
if datetime64 is not None:
temporal_dim, temporal_axis = datetime64
return tuple(coords[dim] if dim != temporal_dim else temporal_axis.
safe_cast(coords[temporal_dim]) for dim in dims)
return tuple(coords[dim] for dim in dims)
class Grid2D(grid.Grid2D):
    """Two-dimensional grid built from an Xarray data array.

    Args:
        data_array: Provided data.
        increasing_axes: If this is true, check that the grid axes are
            increasing: the decreasing axes and the supplied grid will be
            flipped. Default to ``False``.
        geodetic: True, if the axes of the grid represent longitudes and
            latitudes. In this case, the constructor will try to determine
            the axes of longitudes and latitudes according to the value of
            the attribute ``units`` using the following algorithm:

            * if the axis unit is one of the values of the set
              ``degrees_east``, ``degree_east``, ``degree_E``,
              ``degrees_E``, ``degreeE`` or ``degreesE`` the axis represents
              a longitude,
            * if the axis unit is one of the values of the set
              ``degrees_north``, ``degree_north``, ``degree_N``,
              ``degrees_N`` or ``degreesN`` the axis represents a latitude.

            If this option is false, the axes will be considered Cartesian.
            Default to ``True``.

    Raises:
        ValueError: if the provided data array doesn't define a
            longitude/latitude axis if ``geodetic`` is True.
        ValueError: if the number of dimensions is different from 2.
    """

    def __init__(self,
                 data_array: xr.DataArray,
                 increasing_axes: bool = False,
                 geodetic: bool = True):
        # Names of the dimensions, in the order of the tensor handled by the
        # library.
        self._dims = _dims_from_data_array(data_array, geodetic)
        x_name, y_name = self._dims
        flip_mode = 'inplace' if increasing_axes else None
        super().__init__(
            # The X axis is declared circular only for geodetic grids.
            core.Axis(data_array.coords[x_name].values, is_circle=geodetic),
            core.Axis(data_array.coords[y_name].values),
            data_array.transpose(*self._dims).values,
            increasing_axes=flip_mode)

    def bivariate(self, coords: dict, *args, **kwargs) -> np.ndarray:
        """Interpolate the grid at the given coordinates.

        Args:
            coords: Mapping from dimension names to the
                coordinates to interpolate. Coordinates must be array-like.
            *args: List of arguments provided to the interpolation
                method :py:meth:`pyinterp.bivariate <pyinterp.bivariate>`.
            **kwargs: List of keyword arguments provided to the interpolation
                method :py:meth:`pyinterp.bivariate <pyinterp.bivariate>`.

        Returns:
            The interpolated values.
        """
        positions = _coords(coords, self._dims)
        return interpolator.bivariate(self, *positions, *args, **kwargs)

    def bicubic(self, coords: dict, *args, **kwargs) -> np.ndarray:
        """Interpolate the grid at the given coordinates (bicubic method).

        Args:
            coords: Mapping from dimension names to the
                coordinates to interpolate. Coordinates must be array-like.
            *args: List of arguments provided to the interpolation
                method :py:meth:`pyinterp.bicubic <pyinterp.bicubic>`.
            **kwargs: List of keyword arguments provided to the interpolation
                method :py:meth:`pyinterp.bicubic <pyinterp.bicubic>`.

        Returns:
            The interpolated values.
        """
        positions = _coords(coords, self._dims)
        return interpolator.bicubic(self, *positions, *args, **kwargs)
class Grid3D(grid.Grid3D):
    """Three-dimensional grid built from an Xarray data array.

    The dimension that is neither the X nor the Y axis becomes the Z axis.
    It is handled as a temporal axis when the type of its coordinate is
    ``datetime64`` or ``timedelta64``.

    Args:
        data_array: Provided data array.
        increasing_axes: If this is true, check that the grid axes are
            increasing: the decreasing axes and the supplied grid will be
            flipped. Default to ``False``.
        geodetic: True, if the axes of the grid represent longitudes and
            latitudes. In this case, the constructor will try to determine
            the axes of longitudes and latitudes according to the value of
            the attribute ``units`` using the following algorithm:

            * if the axis unit is one of the values of the set
              ``degrees_east``, ``degree_east``, ``degree_E``,
              ``degrees_E``, ``degreeE`` or ``degreesE`` the axis represents
              a longitude,
            * if the axis unit is one of the values of the set
              ``degrees_north``, ``degree_north``, ``degree_N``,
              ``degrees_N`` or ``degreesN`` the axis represents a latitude.

            If this option is false, the axes will be considered Cartesian.
            Default to ``True``.

    Raises:
        ValueError: if the provided data array doesn't define a
            longitude/latitude axis if ``geodetic`` is True.
        ValueError: if the number of dimensions is different from 3.
    """

    def __init__(self,
                 data_array: xr.DataArray,
                 increasing_axes: bool = False,
                 geodetic: bool = True):
        x, y = _dims_from_data_array(data_array, geodetic, ndims=3)
        z = (set(data_array.dims) - {x, y}).pop()
        # Names of the dimensions, in the order of the tensor handled by the
        # library.
        self._dims = (x, y, z)

        # Should the grid manage a time axis?
        type_name = data_array.coords[z].dtype.name
        self._datetime64: tuple[Hashable, core.TemporalAxis] | None
        if any(tag in type_name for tag in ('datetime64', 'timedelta64')):
            self._datetime64 = z, core.TemporalAxis(
                data_array.coords[z].values)
        else:
            self._datetime64 = None

        flip_mode = 'inplace' if increasing_axes else None
        super().__init__(
            core.Axis(data_array.coords[x].values, is_circle=geodetic),
            core.Axis(data_array.coords[y].values),
            # Reuse the temporal axis built above when there is one.
            self._datetime64[1] if self._datetime64 is not None else
            core.Axis(data_array.coords[z].values),
            data_array.transpose(x, y, z).values,
            increasing_axes=flip_mode)

    def trivariate(self, coords: dict, *args, **kwargs) -> np.ndarray:
        """Interpolate the grid at the given coordinates.

        Args:
            coords: Mapping from dimension names to the coordinates to
                interpolate. Coordinates must be array-like.
            *args: List of arguments provided to the interpolation
                method :py:meth:`pyinterp.trivariate
                <pyinterp.trivariate>`.
            **kwargs: List of keyword arguments provided to the interpolation
                method :py:meth:`pyinterp.trivariate
                <pyinterp.trivariate>`.

        Returns:
            The interpolated values.
        """
        positions = _coords(coords, self._dims, self._datetime64)
        return interpolator.trivariate(self, *positions, *args, **kwargs)

    def bicubic(self, coords: dict, *args, **kwargs) -> np.ndarray:
        """Interpolate the grid at the given coordinates (bicubic method).

        Args:
            coords: Mapping from dimension names to the coordinates to
                interpolate. Coordinates must be array-like.
            *args: List of arguments provided to the interpolation
                method :py:meth:`pyinterp.bicubic <pyinterp.bicubic>`.
            **kwargs: List of keyword arguments provided to the interpolation
                method :py:meth:`pyinterp.bicubic <pyinterp.bicubic>`.

        Returns:
            The interpolated values.
        """
        positions = _coords(coords, self._dims, self._datetime64)
        return interpolator.bicubic(self, *positions, *args, **kwargs)
class Grid4D(grid.Grid4D):
    """Builds a Grid4D from the Xarray data provided.

    The two dimensions that are neither the X nor the Y axis become the Z
    and U axes, in the order in which they appear in the data array. If one
    of them has a ``datetime64`` or ``timedelta64`` coordinate, it is handled
    as a temporal axis and is always used as the Z axis.

    Args:
        data_array: Provided data array.
        increasing_axes: If this is true, check that the grid axes are
            increasing: the decreasing axes and the supplied grid will be
            flipped. Default to ``False``.
        geodetic: True, if the axes of the grid represent longitudes and
            latitudes. In this case, the constructor will try to determine
            the axes of longitudes and latitudes according to the value of
            the attribute ``units`` using the following algorithm:

            * if the axis unit is one of the values of the set
              ``degrees_east``, ``degree_east``, ``degree_E``,
              ``degrees_E``, ``degreeE`` or ``degreesE`` the axis represents
              a longitude,
            * if the axis unit is one of the values of the set
              ``degrees_north``, ``degree_north``, ``degree_N``,
              ``degrees_N`` or ``degreesN`` the axis represents a latitude.

            If this option is false, the axes will be considered Cartesian.
            Default to ``True``.

    Raises:
        ValueError: if the provided data array doesn't define a
            longitude/latitude axis if ``geodetic`` is True.
        ValueError: if the number of dimensions is different from 4.
        ValueError: if both remaining dimensions are temporal axes.
    """

    def __init__(self,
                 data_array: xr.DataArray,
                 increasing_axes: bool = False,
                 geodetic: bool = True):
        x, y = _dims_from_data_array(data_array, geodetic, ndims=4)
        # Keep the order of the dimensions of the data array: unpacking a
        # set would give an arbitrary order, which may differ from one
        # interpreter run to another.
        z, u = (dim for dim in data_array.dims if dim not in (x, y))

        # Should the grid manage a time axis?
        self._datetime64: tuple[Hashable, core.TemporalAxis] | None = None
        for name in (z, u):
            type_name = data_array.coords[name].dtype.name
            # Same temporal types as those accepted by Grid3D.
            if 'datetime64' in type_name or 'timedelta64' in type_name:
                if self._datetime64 is not None:
                    raise ValueError('unable to handle two time axes')
                self._datetime64 = name, core.TemporalAxis(
                    data_array.coords[name].values)

        # The time axis is the Z axis.
        if self._datetime64 is not None and self._datetime64[0] == u:
            z, u = u, z

        # Names of the dimensions in the order of the tensor handled by the
        # library
        self._dims = (x, y, z, u)
        super().__init__(
            core.Axis(data_array.coords[x].values, is_circle=geodetic),
            core.Axis(data_array.coords[y].values),
            core.Axis(data_array.coords[z].values)
            if self._datetime64 is None else self._datetime64[1],
            core.Axis(data_array.coords[u].values),
            data_array.transpose(x, y, z, u).values,
            increasing_axes='inplace' if increasing_axes else None)

    def quadrivariate(self, coords: dict, *args, **kwargs) -> np.ndarray:
        """Evaluate the interpolation defined for the given coordinates.

        Args:
            coords: Mapping from dimension names to the
                coordinates to interpolate. Coordinates must be array-like.
            *args: List of arguments provided to the interpolation
                method :py:meth:`pyinterp.quadrivariate
                <pyinterp.quadrivariate>`.
            **kwargs: List of keyword arguments provided to the interpolation
                method :py:meth:`pyinterp.quadrivariate
                <pyinterp.quadrivariate>`.

        Returns:
            The interpolated values.
        """
        return interpolator.quadrivariate(
            self, *_coords(coords, self._dims, self._datetime64), *args,
            **kwargs)

    def bicubic(self, coords: dict, *args, **kwargs) -> np.ndarray:
        """Evaluate the bicubic interpolation for the given coordinates.

        Args:
            coords: Mapping from dimension names to the coordinates to
                interpolate. Coordinates must be array-like.
            *args: List of arguments provided to the interpolation
                method :py:meth:`pyinterp.bicubic <pyinterp.bicubic>`.
            **kwargs: List of keyword arguments provided to the interpolation
                method :py:meth:`pyinterp.bicubic <pyinterp.bicubic>`.

        Returns:
            The interpolated values.
        """
        return interpolator.bicubic(
            self, *_coords(coords, self._dims, self._datetime64), *args,
            **kwargs)
class RegularGridInterpolator:
    """Interpolation on a regular grid in arbitrary dimensions.

    The data must be defined on a regular grid; the grid spacing however may
    be uneven. Linear, nearest neighbors, inverse distance weighting and
    bicubic interpolation are supported.

    Args:
        array: The array defining the regular grid in ``n`` dimensions.
        increasing_axes: If this is true, check that the grid axes are
            increasing: the decreasing axes and the supplied grid will be
            flipped. Default to ``True``.
        geodetic: True, if the axes of the grid represent longitudes and
            latitudes. In this case, the constructor will try to determine
            the axes of longitudes and latitudes according to the value of
            the attribute ``units`` using the following algorithm:

            * if the axis unit is one of the values of the set
              ``degrees_east``, ``degree_east``, ``degree_E``,
              ``degrees_E``, ``degreeE`` or ``degreesE`` the axis represents
              a longitude,
            * if the axis unit is one of the values of the set
              ``degrees_north``, ``degree_north``, ``degree_N``,
              ``degrees_N`` or ``degreesN`` the axis represents a latitude.

            If this option is false, the axes will be considered Cartesian.
            Default to ``True``.

    Raises:
        ValueError: if the provided data array doesn't define a
            longitude/latitude axis if ``geodetic`` is True.
        NotImplementedError: if the number of dimensions in the array is
            less than 2 or more than 4.
    """

    def __init__(self,
                 array: xr.DataArray,
                 increasing_axes: bool = True,
                 geodetic: bool = True):
        # The grid class and the default interpolation method both depend on
        # the number of dimensions of the array.
        ndim = len(array.shape)
        if ndim == 2:
            self._grid: (Grid2D | Grid3D | Grid4D) = Grid2D(
                array, increasing_axes=increasing_axes, geodetic=geodetic)
            self._interp = self._grid.bivariate
        elif ndim == 3:
            self._grid = Grid3D(array,
                                increasing_axes=increasing_axes,
                                geodetic=geodetic)
            self._interp = self._grid.trivariate
        elif ndim == 4:
            self._grid = Grid4D(array,
                                increasing_axes=increasing_axes,
                                geodetic=geodetic)
            self._interp = self._grid.quadrivariate
        else:
            raise NotImplementedError(
                'Only the 2D, 3D or 4D grids can be interpolated.')

    def __getstate__(self) -> tuple[bytes]:
        # Walk around a bug with pybind11 and pickle starting with Python 3.9
        # Serialize the object here with highest protocol.
        return (pickle.dumps((self._grid, self._interp),
                             protocol=pickle.HIGHEST_PROTOCOL), )

    def __setstate__(self, state: tuple[bytes]) -> None:
        # Walk around a bug with pybind11 and pickle starting with Python 3.9
        # Deserialize the object here with highest protocol.
        self._grid, self._interp = pickle.loads(state[0])

    @property
    def ndim(self) -> int:
        """Gets the number of array dimensions.

        Returns:
            Number of array dimensions.
        """
        return self._grid.array.ndim

    @property
    def grid(self) -> Grid2D | Grid3D | Grid4D:
        """Gets the instance handling the regular grid for interpolations.

        Returns:
            The regular grid.
        """
        return self._grid

    def __call__(self,
                 coords: dict,
                 method: str = 'bilinear',
                 bounds_error: bool = False,
                 bicubic_kwargs: dict | None = None,
                 num_threads: int = 0,
                 **kwargs) -> np.ndarray:
        """Interpolation at coordinates.

        Args:
            coords: Mapping from dimension names to the new coordinates.
                New coordinate can be a scalar, array-like.
            method: The method of interpolation to perform. Supported are
                ``bicubic``, ``bilinear``, ``nearest``, and
                ``inverse_distance_weighting``. Default to ``bilinear``.
            bounds_error: If True, when interpolated values are requested
                outside of the domain of the input data, a
                :py:class:`ValueError` is raised. If False, then `nan` is
                used.
            bicubic_kwargs: A dictionary of keyword arguments to pass on to
                the :py:func:`bicubic <pyinterp.bicubic>` function. This is
                useful to control the parameters of this interpolator: window
                size in x, y and the edge control of the calculation windows.
                Only used when ``method`` is ``bicubic``.
            num_threads: The number of threads to use for the computation. If
                0 all CPUs are used. If 1 is given, no parallel computing code
                is used at all, which is useful for debugging. Defaults to
                ``0``.
            **kwargs: List of keyword arguments provided to the interpolation
                method :py:meth:`pyinterp.bivariate <pyinterp.bivariate>`,
                :py:meth:`pyinterp.trivariate <pyinterp.trivariate>` or
                :py:meth:`pyinterp.quadrivariate <pyinterp.quadrivariate>`
                depending on the number of dimensions of the grid. These
                arguments are not forwarded when ``method`` is ``bicubic``.

        Returns:
            New array on the new coordinates.
        """
        if method == 'bicubic':
            # The bicubic interpolator has its own set of options.
            bicubic_kwargs = bicubic_kwargs or {}
            return self._grid.bicubic(coords,
                                      bounds_error=bounds_error,
                                      num_threads=num_threads,
                                      **bicubic_kwargs)
        return self._interp(coords,
                            interpolator=method,
                            bounds_error=bounds_error,
                            num_threads=num_threads,
                            **kwargs)