# Source code for pyinterp.dask

# Copyright (c) 2026 CNES.
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
"""Distributed computation support for statistics using Dask.

This module provides functions to compute statistics on dask arrays
using the pyinterp statistics classes. Dask is an optional dependency.

Example usage:

    >>> import dask.array as da
    >>> import numpy as np
    >>> import pyinterp
    >>> import pyinterp.dask as dask_stats

    Create a dask array

    >>> x = da.random.random((10000,), chunks=1000)

    Compute descriptive statistics

    >>> stats = dask_stats.descriptive_statistics(x)
    >>> print(stats.mean())

    Compute quantiles using TDigest

    >>> digest = dask_stats.tdigest(x)
    >>> print(digest.quantile(0.5))
"""

from __future__ import annotations

import copy
from typing import TYPE_CHECKING

import numpy as np

from . import core


if TYPE_CHECKING:
    import dask.array

__all__ = [
    "binning1d",
    "binning2d",
    "descriptive_statistics",
    "histogram2d",
    "tdigest",
]


def _check_dask_available() -> None:
    """Ensure that the optional ``dask.array`` package can be imported.

    Raises:
        ImportError: If importing ``dask.array`` fails. The original import
            error is chained as the cause and the message contains an
            installation hint.

    """
    try:
        import dask.array  # noqa: F401, PLC0415
    except ImportError as import_error:
        # Re-raise with an actionable message while keeping the root cause.
        message = (
            "dask is required for distributed computation. "
            "Install it with: pip install dask[array]"
        )
        raise ImportError(message) from import_error


def _validate_dask_array(
    arr: object,
    name: str,
) -> dask.array.Array:
    """Return ``arr`` unchanged after checking that it is a dask array.

    Args:
        arr: Object to check.
        name: Parameter name inserted into the error message.

    Returns:
        ``arr`` itself, when it is an instance of ``dask.array.Array``.

    Raises:
        TypeError: If ``arr`` is not an instance of ``dask.array.Array``.

    """
    import dask.array as da  # noqa: PLC0415 (to avoid import issues)

    if isinstance(arr, da.Array):
        return arr

    # Report the offending type by its class name only.
    type_name = type(arr).__name__
    message = f"{name} must be a dask array, got {type_name}"
    raise TypeError(message)


def _validate_shapes_match(
    values: dask.array.Array,
    weights: dask.array.Array | None,
) -> None:
    """Check that ``weights``, when given, has the same shape as ``values``.

    Args:
        values: Array of values.
        weights: Array of weights, or ``None`` to skip the check.

    Raises:
        ValueError: If ``weights`` is provided and its shape differs from
            the shape of ``values``.

    """
    # Nothing to compare when no weights were supplied.
    if weights is None:
        return

    values_shape, weights_shape = values.shape, weights.shape
    if values_shape == weights_shape:
        return

    message = (
        f"values and weights must have the same shape, "
        f"got {values_shape} and {weights_shape}"
    )
    raise ValueError(message)


def descriptive_statistics(
    values: dask.array.Array,
    weights: dask.array.Array | None = None,
    axis: list[int] | None = None,
    *,
    dtype: str | type | np.dtype | None = None,
) -> core.DescriptiveStatisticsHolder:
    """Compute descriptive statistics of a dask array.

    A ``core.DescriptiveStatistics`` accumulator (mean, variance, skewness,
    kurtosis, etc.) is built for every block handed over by
    ``dask.array.map_blocks``. The accumulators are then combined with
    in-place addition into a single result.

    Args:
        values: Dask array holding the samples.
        weights: Optional dask array of weights. Its shape must equal the
            shape of ``values``.
        axis: Axis or axes along which the statistics are computed. ``None``
            computes them over all axes.
        dtype: Data type used for the computation: "float32", "float64",
            ``np.float32``, ``np.float64`` or ``None`` (float64 by default).

    Returns:
        The DescriptiveStatistics instance obtained by merging the
        per-block results.

    Raises:
        ImportError: If dask is not installed.
        TypeError: If ``values`` or ``weights`` is not a dask array.
        ValueError: If ``values`` and ``weights`` have different shapes.

    Example:
        >>> import dask.array as da
        >>> import numpy as np
        >>> import pyinterp.dask as dask_stats

        >>> values = da.random.random((10000,), chunks=1000)
        >>> stats = dask_stats.descriptive_statistics(values)
        >>> print(f"Mean: {stats.mean():.4f}")
        >>> print(f"Std: {np.sqrt(stats.variance()):.4f}")

    """
    _check_dask_available()
    import dask.array as da  # noqa: PLC0415 (to avoid import issues)

    values = _validate_dask_array(values, "values")
    if weights is not None:
        weights = _validate_dask_array(weights, "weights")
    _validate_shapes_match(values, weights)

    def _block_statistics(
        values_block: np.ndarray,
        weights_block: np.ndarray | None,
        axis: list[int] | None,
        dtype: str | type | np.dtype | None,
        block_id: tuple[int, ...] | None = None,
    ) -> np.ndarray:
        """Wrap the statistics of one block in a 1-element object array."""
        boxed = np.empty((1,), dtype=object)
        boxed[0] = core.DescriptiveStatistics(
            values_block,
            weights=weights_block,
            axis=axis,
            dtype=dtype,
        )
        return boxed

    # ``weights`` is forwarded as is: when it is None the block function
    # simply receives None for ``weights_block``.
    # NOTE(review): every input axis is listed in ``drop_axis``; check in the
    # dask ``map_blocks`` documentation whether this makes dask concatenate
    # the chunks before calling the block function -- TODO confirm.
    per_block = da.map_blocks(
        _block_statistics,
        values,
        weights,
        axis,
        dtype,
        dtype=object,
        drop_axis=list(range(values.ndim)),
        new_axis=0,
    )

    # Materialise the per-block accumulators, then fold them into the first.
    partials = per_block.compute()
    combined = partials[0]
    for partial in partials[1:]:
        combined += partial
    return combined
def tdigest(
    values: dask.array.Array,
    weights: dask.array.Array | None = None,
    axis: list[int] | None = None,
    compression: int = 100,
    *,
    dtype: str | type | np.dtype | None = None,
) -> core.TDigestHolder:
    """Estimate quantiles of a dask array with the T-Digest algorithm.

    A ``core.TDigest`` is built for every block handed over by
    ``dask.array.map_blocks``. The digests are then combined with in-place
    addition into a single result from which approximate quantiles can be
    queried.

    Args:
        values: Dask array holding the samples.
        weights: Optional dask array of weights. Its shape must equal the
            shape of ``values``.
        axis: Axis or axes along which the quantiles are computed. ``None``
            computes them over all axes.
        compression: T-Digest compression parameter. Higher values give more
            accurate results but use more memory. Defaults to 100.
        dtype: Data type used for the computation: "float32", "float64",
            ``np.float32``, ``np.float64`` or ``None`` (float64 by default).

    Returns:
        The TDigest instance obtained by merging the per-block digests.

    Raises:
        ImportError: If dask is not installed.
        TypeError: If ``values`` or ``weights`` is not a dask array.
        ValueError: If ``values`` and ``weights`` have different shapes.

    Example:
        >>> import dask.array as da
        >>> import pyinterp.dask as dask_stats

        >>> values = da.random.random((10000,), chunks=1000)
        >>> digest = dask_stats.tdigest(values)
        >>> print(f"Median: {digest.quantile(0.5):.4f}")
        >>> print(f"Q25: {digest.quantile(0.25):.4f}")
        >>> print(f"Q75: {digest.quantile(0.75):.4f}")

    """
    _check_dask_available()
    import dask.array as da  # noqa: PLC0415 (to avoid import issues)

    values = _validate_dask_array(values, "values")
    if weights is not None:
        weights = _validate_dask_array(weights, "weights")
    _validate_shapes_match(values, weights)

    def _block_digest(
        values_block: np.ndarray,
        weights_block: np.ndarray | None,
        axis: list[int] | None,
        compression: int,
        dtype: str | type | np.dtype | None,
        block_id: tuple[int, ...] | None = None,
    ) -> np.ndarray:
        """Wrap the digest of one block in a 1-element object array."""
        boxed = np.empty((1,), dtype=object)
        boxed[0] = core.TDigest(
            values_block,
            weights=weights_block,
            axis=axis,
            compression=compression,
            dtype=dtype,
        )
        return boxed

    # ``weights`` is forwarded as is: when it is None the block function
    # simply receives None for ``weights_block``.
    # NOTE(review): every input axis is listed in ``drop_axis``; check in the
    # dask ``map_blocks`` documentation whether this makes dask concatenate
    # the chunks before calling the block function -- TODO confirm.
    per_block = da.map_blocks(
        _block_digest,
        values,
        weights,
        axis,
        compression,
        dtype,
        dtype=object,
        drop_axis=list(range(values.ndim)),
        new_axis=0,
    )

    # Materialise the per-block digests, then fold them into the first one.
    partials = per_block.compute()
    combined = partials[0]
    for partial in partials[1:]:
        combined += partial
    return combined
def binning1d(
    x: dask.array.Array,
    z: dask.array.Array,
    binning: core.Binning1DHolder,
    weights: dask.array.Array | None = None,
) -> core.Binning1DHolder:
    """Accumulate the values of dask arrays into 1D bins.

    For every block handed over by ``dask.array.map_blocks`` a fresh
    ``core.Binning1D`` is created and fed with the flattened block data.
    The per-block instances are then combined with in-place addition.

    Args:
        x: Dask array of x coordinates.
        z: Dask array of values to bin. Must have the same shape as ``x``.
        binning: Binning1D instance defining the bins. Only its axis
            (``binning.x``) and its range (``binning.range()``) are read to
            build the per-block instances; nothing is pushed into
            ``binning`` itself.
        weights: Optional dask array of weights with the same shape as
            ``z``.

    Returns:
        A new Binning1D instance holding the merged per-block statistics.

    Raises:
        ImportError: If dask is not installed.
        TypeError: If ``x``, ``z`` or ``weights`` is not a dask array.
        ValueError: If ``x`` and ``z`` have different shapes, or if the
            shape of ``weights`` differs from the shape of ``z``.

    Example:
        >>> import dask.array as da
        >>> import numpy as np
        >>> import pyinterp
        >>> import pyinterp.dask as dask_stats

        Create bins and data

        >>> axis = pyinterp.Axis(np.linspace(0, 10, 11))
        >>> binning = pyinterp.Binning1D(axis)

        Create dask arrays

        >>> x = da.random.uniform(0, 10, size=(10000,), chunks=1000)
        >>> z = da.random.random((10000,), chunks=1000)

        Compute binned statistics

        >>> result = dask_stats.binning1d(x, z, binning)
        >>> print(result.mean())

    """
    _check_dask_available()
    import dask.array as da  # noqa: PLC0415 (to avoid import issues)

    x = _validate_dask_array(x, "x")
    z = _validate_dask_array(z, "z")
    if not x.shape == z.shape:
        msg = f"x and z must have the same shape, got {x.shape} and {z.shape}"
        raise ValueError(msg)

    if weights is not None:
        weights = _validate_dask_array(weights, "weights")
        if not weights.shape == z.shape:
            msg = (
                f"weights and z must have the same shape, "
                f"got {weights.shape} and {z.shape}"
            )
            raise ValueError(msg)

    # Bin definition used to rebuild an empty Binning1D for each block.
    axis = binning.x
    bin_range = binning.range()

    def _block_binning(
        x_block: np.ndarray,
        z_block: np.ndarray,
        weights_block: np.ndarray | None,
        axis: core.Axis,
        bin_range: tuple[float, float] | None,
        block_id: tuple[int, ...] | None = None,
    ) -> np.ndarray:
        """Bin one block and wrap the result in a 1-element object array."""
        # Each block accumulates into its own instance built on a copy of
        # the axis.
        block_binning = core.Binning1D(copy.copy(axis), range=bin_range)
        if weights_block is None:
            flat_weights = None
        else:
            flat_weights = weights_block.ravel()
        block_binning.push(
            x_block.ravel(),
            z_block.ravel(),
            weights=flat_weights,
        )
        boxed = np.empty((1,), dtype=object)
        boxed[0] = block_binning
        return boxed

    # ``weights`` is forwarded as is: when it is None the block function
    # simply receives None for ``weights_block``.
    # NOTE(review): every input axis is listed in ``drop_axis``; check in the
    # dask ``map_blocks`` documentation whether this makes dask concatenate
    # the chunks before calling the block function -- TODO confirm.
    per_block = da.map_blocks(
        _block_binning,
        x,
        z,
        weights,
        axis,
        bin_range,
        dtype=object,
        drop_axis=list(range(x.ndim)),
        new_axis=0,
    )

    # Materialise the per-block binnings, then fold them into the first one.
    partials = per_block.compute()
    combined = partials[0]
    for partial in partials[1:]:
        combined += partial
    return combined
def binning2d(
    x: dask.array.Array,
    y: dask.array.Array,
    z: dask.array.Array,
    binning: core.Binning2DHolder,
    simple: bool = True,
) -> core.Binning2DHolder:
    """Accumulate the values of dask arrays into 2D bins.

    For every block handed over by ``dask.array.map_blocks`` a fresh
    ``core.Binning2D`` is created and fed with the flattened block data.
    The per-block instances are then combined with in-place addition.

    Args:
        x: Dask array of x coordinates.
        y: Dask array of y coordinates.
        z: Dask array of values to bin.
        binning: Binning2D instance defining the bins. Only its axes
            (``binning.x``, ``binning.y``) and its ``spheroid`` are read to
            build the per-block instances; nothing is pushed into
            ``binning`` itself.
        simple: If True, use simple binning (nearest neighbor). If False,
            use linear interpolation to distribute values among neighboring
            bins. Defaults to True.

    Returns:
        A new Binning2D instance holding the merged per-block statistics.

    Raises:
        ImportError: If dask is not installed.
        TypeError: If ``x``, ``y`` or ``z`` is not a dask array.
        ValueError: If ``x``, ``y`` and ``z`` do not all share one shape.

    Example:
        >>> import dask.array as da
        >>> import numpy as np
        >>> import pyinterp
        >>> import pyinterp.dask as dask_stats

        Create bins and data

        >>> x_axis = pyinterp.Axis(np.linspace(0, 10, 11))
        >>> y_axis = pyinterp.Axis(np.linspace(0, 10, 11))
        >>> binning = pyinterp.Binning2D(x_axis, y_axis)

        Create dask arrays

        >>> x = da.random.uniform(0, 10, size=(10000,), chunks=1000)
        >>> y = da.random.uniform(0, 10, size=(10000,), chunks=1000)
        >>> z = da.random.random((10000,), chunks=1000)

        Compute binned statistics

        >>> result = dask_stats.binning2d(x, y, z, binning)
        >>> print(result.mean())

    """
    _check_dask_available()
    import dask.array as da  # noqa: PLC0415 (to avoid import issues)

    x = _validate_dask_array(x, "x")
    y = _validate_dask_array(y, "y")
    z = _validate_dask_array(z, "z")
    if not (x.shape == y.shape == z.shape):
        msg = (
            f"x, y, and z must have the same shape, "
            f"got {x.shape}, {y.shape}, and {z.shape}"
        )
        raise ValueError(msg)

    # Bin definition used to rebuild an empty Binning2D for each block.
    x_axis, y_axis = binning.x, binning.y
    spheroid = binning.spheroid

    def _block_binning(
        x_block: np.ndarray,
        y_block: np.ndarray,
        z_block: np.ndarray,
        x_axis: core.Axis,
        y_axis: core.Axis,
        spheroid: core.geometry.geographic.Spheroid | None,
        simple: bool,
        block_id: tuple[int, ...] | None = None,
    ) -> np.ndarray:
        """Bin one block and wrap the result in a 1-element object array."""
        # Each block accumulates into its own instance built on copies of
        # the axes.
        block_binning = core.Binning2D(
            copy.copy(x_axis),
            copy.copy(y_axis),
            spheroid=spheroid,
        )
        block_binning.push(
            x_block.ravel(),
            y_block.ravel(),
            z_block.ravel(),
            simple=simple,
        )
        boxed = np.empty((1,), dtype=object)
        boxed[0] = block_binning
        return boxed

    # NOTE(review): every input axis is listed in ``drop_axis``; check in the
    # dask ``map_blocks`` documentation whether this makes dask concatenate
    # the chunks before calling the block function -- TODO confirm.
    per_block = da.map_blocks(
        _block_binning,
        x,
        y,
        z,
        x_axis,
        y_axis,
        spheroid,
        simple,
        dtype=object,
        drop_axis=list(range(x.ndim)),
        new_axis=0,
    )

    # Materialise the per-block binnings, then fold them into the first one.
    partials = per_block.compute()
    combined = partials[0]
    for partial in partials[1:]:
        combined += partial
    return combined
def histogram2d(
    x: dask.array.Array,
    y: dask.array.Array,
    z: dask.array.Array,
    histogram: core.Histogram2DHolder,
) -> core.Histogram2DHolder:
    """Accumulate the values of dask arrays into a 2D histogram.

    For every block handed over by ``dask.array.map_blocks`` a fresh
    ``core.Histogram2D`` is created and fed with the flattened block data.
    The per-block instances are then combined with in-place addition.

    Args:
        x: Dask array of x coordinates.
        y: Dask array of y coordinates.
        z: Dask array of values to accumulate.
        histogram: Histogram2D instance defining the grid. Only its axes
            (``histogram.x``, ``histogram.y``) are read to build the
            per-block instances; nothing is pushed into ``histogram``
            itself.

    Returns:
        A new Histogram2D instance holding the merged per-block values.

    Raises:
        ImportError: If dask is not installed.
        TypeError: If ``x``, ``y`` or ``z`` is not a dask array.
        ValueError: If ``x``, ``y`` and ``z`` do not all share one shape.

    Example:
        >>> import dask.array as da
        >>> import numpy as np
        >>> import pyinterp
        >>> import pyinterp.dask as dask_stats

        Create histogram and data

        >>> x_axis = pyinterp.Axis(np.linspace(0, 10, 11))
        >>> y_axis = pyinterp.Axis(np.linspace(0, 10, 11))
        >>> hist = pyinterp.Histogram2D(x_axis, y_axis)

        Create dask arrays

        >>> x = da.random.uniform(0, 10, size=(10000,), chunks=1000)
        >>> y = da.random.uniform(0, 10, size=(10000,), chunks=1000)
        >>> z = da.random.random((10000,), chunks=1000)

        Compute histogram

        >>> result = dask_stats.histogram2d(x, y, z, hist)
        >>> print(result.mean())
        >>> print(result.quantile(0.5))

    """
    _check_dask_available()
    import dask.array as da  # noqa: PLC0415 (to avoid import issues)

    x = _validate_dask_array(x, "x")
    y = _validate_dask_array(y, "y")
    z = _validate_dask_array(z, "z")
    if not (x.shape == y.shape == z.shape):
        msg = (
            f"x, y, and z must have the same shape, "
            f"got {x.shape}, {y.shape}, and {z.shape}"
        )
        raise ValueError(msg)

    # Grid definition used to rebuild an empty Histogram2D for each block.
    # NOTE(review): only the two axes are forwarded; any other setting of
    # ``histogram`` is not propagated to the per-block instances -- confirm
    # that this is intended.
    x_axis, y_axis = histogram.x, histogram.y

    def _block_histogram(
        x_block: np.ndarray,
        y_block: np.ndarray,
        z_block: np.ndarray,
        x_axis: core.Axis,
        y_axis: core.Axis,
        block_id: tuple[int, ...] | None = None,
    ) -> np.ndarray:
        """Fill one histogram and wrap it in a 1-element object array."""
        # Each block accumulates into its own instance built on copies of
        # the axes.
        block_hist = core.Histogram2D(
            copy.copy(x_axis),
            copy.copy(y_axis),
        )
        block_hist.push(
            x_block.ravel(),
            y_block.ravel(),
            z_block.ravel(),
        )
        boxed = np.empty((1,), dtype=object)
        boxed[0] = block_hist
        return boxed

    # NOTE(review): every input axis is listed in ``drop_axis``; check in the
    # dask ``map_blocks`` documentation whether this makes dask concatenate
    # the chunks before calling the block function -- TODO confirm.
    per_block = da.map_blocks(
        _block_histogram,
        x,
        y,
        z,
        x_axis,
        y_axis,
        dtype=object,
        drop_axis=list(range(x.ndim)),
        new_axis=0,
    )

    # Materialise the per-block histograms, then fold them into the first.
    partials = per_block.compute()
    combined = partials[0]
    for partial in partials[1:]:
        combined += partial
    return combined