# Source code for xrspatial.interpolate._kriging

"""Ordinary Kriging interpolation."""

from __future__ import annotations

import warnings

import numpy as np
import xarray as xr

from xrspatial.utils import (ArrayTypeFunctionMapping, _dask_task_name_kwargs, _validate_raster,
                             _validate_scalar)

from ._validation import extract_grid_coords, validate_points
from ._vector import resolve_xyz

try:
    import cupy
except ImportError:
    cupy = None

try:
    import dask.array as da
except ImportError:
    da = None


def _get_xp(arr):
    """Pick the array namespace that matches *arr*.

    Returns the ``cupy`` module for CuPy device arrays (only possible when
    CuPy imported successfully) and ``numpy`` for anything else.
    """
    on_device = cupy is not None and isinstance(arr, cupy.ndarray)
    return cupy if on_device else np


# ---------------------------------------------------------------------------
# Variogram models
# ---------------------------------------------------------------------------

def _spherical(h, c0, c, a):
    """Spherical semivariogram model.

    Rises as ``c0 + c * (1.5 r - 0.5 r**3)`` with ``r = h / a`` for
    ``h < a`` and is flat at the sill ``c0 + c`` from ``h = a`` onward.
    Works on NumPy or CuPy arrays (namespace chosen from *h*).
    """
    xp = _get_xp(h)
    sill = c0 + c
    rising = c0 + c * (1.5 * h / a - 0.5 * (h / a) ** 3)
    return xp.where(h < a, rising, sill)


def _exponential(h, c0, c, a):
    """Exponential semivariogram model.

    ``c0 + c * (1 - exp(-3 h / a))``; the factor 3 makes the curve reach
    about 95% of the partial sill *c* at ``h = a``.  Works on NumPy or
    CuPy arrays (namespace chosen from *h*).
    """
    xp = _get_xp(h)
    decay = xp.exp(-3.0 * h / a)
    return c0 + c * (1.0 - decay)


def _gaussian(h, c0, c, a):
    """Gaussian semivariogram model.

    ``c0 + c * (1 - exp(-3 (h / a)**2))``; parabolic near the origin and
    about 95% of the partial sill *c* at ``h = a``.  Works on NumPy or
    CuPy arrays (namespace chosen from *h*).
    """
    xp = _get_xp(h)
    scaled_sq = (h / a) ** 2
    return c0 + c * (1.0 - xp.exp(-3.0 * scaled_sq))


# Registry of supported variogram models, keyed by the public
# ``variogram_model`` name accepted by kriging().  Each callable has the
# signature ``f(h, c0, c, a)``.
_VARIOGRAM_MODELS = dict(
    spherical=_spherical,
    exponential=_exponential,
    gaussian=_gaussian,
)


# ---------------------------------------------------------------------------
# Experimental variogram
# ---------------------------------------------------------------------------

def _experimental_variogram(x, y, z, nlags):
    """Compute binned experimental variogram from point data."""
    n = len(x)
    i_idx, j_idx = np.triu_indices(n, k=1)
    dx = x[i_idx] - x[j_idx]
    dy = y[i_idx] - y[j_idx]
    dists = np.sqrt(dx ** 2 + dy ** 2)
    semivar = 0.5 * (z[i_idx] - z[j_idx]) ** 2

    if dists.size == 0:
        # Fewer than two points: no pairs, so no spatial structure to bin.
        return np.array([]), np.array([])

    max_dist = dists.max() / 2.0
    if max_dist <= 0:
        return np.array([]), np.array([])

    bins = np.linspace(0, max_dist, nlags + 1)
    lag_centers = 0.5 * (bins[:-1] + bins[1:])
    lag_sv = np.zeros(nlags, dtype=np.float64)
    lag_count = np.zeros(nlags, dtype=np.int64)

    bin_idx = np.digitize(dists, bins) - 1
    for k in range(nlags):
        mask = bin_idx == k
        if mask.any():
            lag_sv[k] = semivar[mask].mean()
            lag_count[k] = mask.sum()

    valid = lag_count > 0
    return lag_centers[valid], lag_sv[valid]


# ---------------------------------------------------------------------------
# Variogram fitting
# ---------------------------------------------------------------------------

def _fit_variogram(lag_h, lag_sv, model_func):
    """Fit variogram model parameters via curve_fit.

    Returns (c0, c, a).
    """
    from scipy.optimize import curve_fit

    c0_init = float(lag_sv[0]) if len(lag_sv) > 0 else 0.0
    c_init = float(lag_sv.max() - c0_init) if len(lag_sv) > 0 else 1.0
    a_init = float(lag_h[-1]) if len(lag_h) > 0 else 1.0

    c_init = max(c_init, 1e-10)
    a_init = max(a_init, 1e-10)

    try:
        popt, _ = curve_fit(
            model_func, lag_h, lag_sv,
            p0=[c0_init, c_init, a_init],
            bounds=([0, 0, 1e-12], [np.inf, np.inf, np.inf]),
            maxfev=5000,
        )
    except (RuntimeError, ValueError):
        popt = np.array([c0_init, c_init, a_init])

    return popt


# ---------------------------------------------------------------------------
# Kriging matrix
# ---------------------------------------------------------------------------

def _build_kriging_matrix(x, y, vario_func):
    """Build the (N+1)x(N+1) ordinary-kriging matrix and return K_inv."""
    n = len(x)
    dx = x[:, np.newaxis] - x[np.newaxis, :]
    dy = y[:, np.newaxis] - y[np.newaxis, :]
    D = np.sqrt(dx ** 2 + dy ** 2)

    K = np.zeros((n + 1, n + 1), dtype=np.float64)
    G = vario_func(D)
    # The semivariogram has gamma(0) = 0 by definition; vario_func(0)
    # returns the nugget c0, which is the one-sided limit as h -> 0+,
    # not the value at h = 0.  Force the diagonal to 0 so a non-zero
    # nugget is not placed on the matrix diagonal (which would force
    # exact interpolation and bias the kriging variance downward).
    np.fill_diagonal(G, 0.0)
    K[:n, :n] = G
    K[:n, n] = 1.0
    K[n, :n] = 1.0

    try:
        K_inv = np.linalg.inv(K)
    except np.linalg.LinAlgError:
        K[:n, :n] += 1e-10 * np.eye(n)
        try:
            K_inv = np.linalg.inv(K)
        except np.linalg.LinAlgError:
            warnings.warn(
                "kriging(): kriging matrix is singular even after "
                "regularisation; predictions will be NaN."
            )
            return None

    return K_inv


# ---------------------------------------------------------------------------
# Numpy prediction
# ---------------------------------------------------------------------------

def _dual_z_aug(xp, z_pts):
    """Return ``[z, 0]`` for the dual kriging form on array module *xp*."""
    z_aug = xp.zeros(len(z_pts) + 1, dtype=np.float64)
    z_aug[:len(z_pts)] = z_pts
    return z_aug


def _kriging_predict(x_pts, y_pts, z_pts, x_grid, y_grid,
                     vario_func, K_inv, return_variance, dual_w=None):
    """Vectorised kriging prediction for a grid chunk.

    ``dual_w`` is an optional precomputed dual weight vector
    ``K_inv @ [z, 0]`` used by the no-variance path; the dask backend
    passes it in so the O(N^2) matvec runs once instead of per chunk.
    """
    n = len(x_pts)
    gx, gy = np.meshgrid(x_grid, y_grid)
    gx_flat = gx.ravel()
    gy_flat = gy.ravel()

    dx = gx_flat[:, np.newaxis] - x_pts[np.newaxis, :]
    dy = gy_flat[:, np.newaxis] - y_pts[np.newaxis, :]
    dists = np.sqrt(dx ** 2 + dy ** 2)

    k0 = np.empty((len(gx_flat), n + 1), dtype=np.float64)
    k0[:, :n] = vario_func(dists)
    k0[:, n] = 1.0

    variance = None
    if return_variance:
        # w = K_inv @ k0 for each pixel (K_inv is symmetric)
        w = k0 @ K_inv
        prediction = (w[:, :n] * z_pts[np.newaxis, :]).sum(axis=1)
        variance = np.sum(w * k0, axis=1)
        variance = variance.reshape(len(y_grid), len(x_grid))
    else:
        # Prediction only: the weights are needed solely to dot against
        # z, so fold K_inv into z up front (dual kriging form).
        # k0 @ (K_inv @ z_aug) equals (k0 @ K_inv)[:, :n] @ z because
        # z_aug[n] == 0, and it replaces the
        # (grid_pixels, N+1) x (N+1, N+1) matmul and the
        # (grid_pixels, N+1) `w` temporary with a single matvec.
        if dual_w is None:
            dual_w = K_inv @ _dual_z_aug(np, z_pts)
        prediction = k0 @ dual_w
    prediction = prediction.reshape(len(y_grid), len(x_grid))

    return prediction, variance


# ---------------------------------------------------------------------------
# Numpy backend wrapper
# ---------------------------------------------------------------------------

def _kriging_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
                   vario_func, K_inv, return_variance, template_data):
    """NumPy backend: predict the whole template grid in one call.

    ``template_data`` is accepted only for signature parity with the other
    backends and is not used.  Returns ``(prediction, variance)``.
    """
    prediction, variance = _kriging_predict(
        x_pts, y_pts, z_pts, x_grid, y_grid,
        vario_func, K_inv, return_variance,
    )
    return prediction, variance


# ---------------------------------------------------------------------------
# Dask + numpy backend
# ---------------------------------------------------------------------------

def _kriging_dask_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
                        vario_func, K_inv, return_variance, template_data):
    """Dask + NumPy backend: krige lazily, one template chunk per task.

    Each task slices the grid coordinates covered by its chunk (from
    ``block_info``) and runs ``_kriging_predict`` on them.  When dask calls
    a task function without ``block_info`` (presumably for meta inference),
    the input block is echoed back, stacked twice for the variance path.

    Returns ``(prediction, variance)`` dask arrays; *variance* is None
    unless *return_variance*.
    """

    def _chunk_axes(block_info):
        # Grid coordinates (x, y) covered by the current chunk.
        loc = block_info[0]['array-location']
        return x_grid[slice(*loc[1])], y_grid[slice(*loc[0])]

    if return_variance:
        # Prediction and variance share the per-chunk pipeline (dists,
        # k0, weights), so produce both in one map_blocks pass as a
        # stacked (2, rows, cols) array and slice the two channels.
        def _chunk_both(block, block_info=None):
            if block_info is None:
                return np.stack([block, block])
            x_sl, y_sl = _chunk_axes(block_info)
            pred, var = _kriging_predict(x_pts, y_pts, z_pts, x_sl, y_sl,
                                         vario_func, K_inv, True)
            return np.stack([pred, var])

        stacked = da.map_blocks(
            _chunk_both, template_data, dtype=np.float64,
            new_axis=0, chunks=((2,),) + template_data.chunks,
            **_dask_task_name_kwargs('xrspatial.kriging_variance'),
        )
        return stacked[0], stacked[1]

    # The dual weight vector does not depend on the chunk; compute it once
    # here instead of once per task.
    dual_w = K_inv @ _dual_z_aug(np, z_pts)

    def _chunk_pred(block, block_info=None):
        if block_info is None:
            return block
        x_sl, y_sl = _chunk_axes(block_info)
        pred, _ = _kriging_predict(x_pts, y_pts, z_pts, x_sl, y_sl,
                                   vario_func, K_inv, False,
                                   dual_w=dual_w)
        return pred

    prediction = da.map_blocks(_chunk_pred, template_data,
                               dtype=np.float64,
                               **_dask_task_name_kwargs('xrspatial.kriging'))
    return prediction, None


# ---------------------------------------------------------------------------
# CuPy prediction
# ---------------------------------------------------------------------------

def _kriging_predict_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
                          vario_func, K_inv, return_variance,
                          dual_w=None):
    """Vectorised kriging prediction for one grid chunk on the GPU.

    Mirrors ``_kriging_predict``.  The point arrays, grid coordinates and
    ``K_inv`` may be host or device arrays; each goes through
    ``cupy.asarray``, which does not copy device-resident input.
    ``dual_w`` is an optional precomputed device-side ``K_inv @ [z, 0]``
    for the prediction-only path.

    Returns ``(prediction, variance)`` CuPy arrays of shape
    ``(len(y_grid), len(x_grid))``; *variance* is None unless
    *return_variance*.
    """
    n_pts = len(x_pts)
    out_shape = (len(y_grid), len(x_grid))

    pts_x = cupy.asarray(x_pts)
    pts_y = cupy.asarray(y_pts)
    pts_z = cupy.asarray(z_pts)
    col_coords = cupy.asarray(x_grid)
    row_coords = cupy.asarray(y_grid)
    K_inv_dev = cupy.asarray(K_inv)

    mesh_x, mesh_y = cupy.meshgrid(col_coords, row_coords)
    px = mesh_x.ravel()
    py = mesh_y.ravel()

    # Pixel-to-point separation distances, shape (pixels, N).
    sep = cupy.sqrt((px[:, None] - pts_x[None, :]) ** 2
                    + (py[:, None] - pts_y[None, :]) ** 2)

    rhs = cupy.empty((len(px), n_pts + 1), dtype=np.float64)
    rhs[:, :n_pts] = vario_func(sep)
    rhs[:, n_pts] = 1.0

    if not return_variance:
        # Dual kriging form; see _kriging_predict for the derivation.
        if dual_w is None:
            dual_w = K_inv_dev @ _dual_z_aug(cupy, pts_z)
        return (rhs @ dual_w).reshape(out_shape), None

    weights = rhs @ K_inv_dev
    pred = (weights[:, :n_pts] * pts_z[None, :]).sum(axis=1)
    var = cupy.sum(weights * rhs, axis=1)
    return pred.reshape(out_shape), var.reshape(out_shape)


# ---------------------------------------------------------------------------
# CuPy backend wrapper
# ---------------------------------------------------------------------------

def _kriging_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
                  vario_func, K_inv, return_variance, template_data):
    """CuPy backend: predict the whole template grid on the GPU in one call.

    ``template_data`` is accepted only for signature parity with the other
    backends and is not used.  Returns ``(prediction, variance)``.
    """
    prediction, variance = _kriging_predict_cupy(
        x_pts, y_pts, z_pts, x_grid, y_grid,
        vario_func, K_inv, return_variance,
    )
    return prediction, variance


# ---------------------------------------------------------------------------
# Dask + CuPy backend
# ---------------------------------------------------------------------------

def _kriging_dask_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
                       vario_func, K_inv, return_variance, template_data):
    """Dask + CuPy backend: krige lazily on the GPU, one chunk per task.

    The point arrays and ``K_inv`` are identical for every chunk, so they
    are uploaded to the device once here.  ``_kriging_predict_cupy`` passes
    them through ``cupy.asarray``, which does not copy device-resident
    arrays.  ``K_inv`` is (N+1) x (N+1), so avoiding a per-chunk upload is
    what matters for large N.  Under a threaded or synchronous scheduler
    the task closures share these device buffers by reference; a
    distributed scheduler would re-serialise them per task, which is no
    worse than uploading per chunk.

    Returns ``(prediction, variance)`` dask arrays backed by CuPy;
    *variance* is None unless *return_variance*.
    """
    pts_x = cupy.asarray(x_pts)
    pts_y = cupy.asarray(y_pts)
    pts_z = cupy.asarray(z_pts)
    K_inv_dev = cupy.asarray(K_inv)

    def _chunk_axes(block_info):
        # Host-side grid coordinates (x, y) covered by the current chunk.
        loc = block_info[0]['array-location']
        return x_grid[slice(*loc[1])], y_grid[slice(*loc[0])]

    if return_variance:
        # Single pass for prediction and variance, stacked on a new
        # leading axis of length 2; see _kriging_dask_numpy.
        def _chunk_both(block, block_info=None):
            if block_info is None:
                return cupy.stack([block, block])
            x_sl, y_sl = _chunk_axes(block_info)
            pred, var = _kriging_predict_cupy(pts_x, pts_y, pts_z,
                                              x_sl, y_sl, vario_func,
                                              K_inv_dev, True)
            return cupy.stack([pred, var])

        stacked = da.map_blocks(
            _chunk_both, template_data, dtype=np.float64,
            new_axis=0, chunks=((2,),) + template_data.chunks,
            meta=cupy.array((), dtype=np.float64),
            **_dask_task_name_kwargs('xrspatial.kriging_variance'),
        )
        return stacked[0], stacked[1]

    # Chunk-invariant dual weight vector, computed once on the device.
    dual_w = K_inv_dev @ _dual_z_aug(cupy, pts_z)

    def _chunk_pred(block, block_info=None):
        if block_info is None:
            return block
        x_sl, y_sl = _chunk_axes(block_info)
        pred, _ = _kriging_predict_cupy(pts_x, pts_y, pts_z, x_sl, y_sl,
                                        vario_func, K_inv_dev, False,
                                        dual_w=dual_w)
        return pred

    prediction = da.map_blocks(
        _chunk_pred, template_data, dtype=np.float64,
        meta=cupy.array((), dtype=np.float64),
        **_dask_task_name_kwargs('xrspatial.kriging'),
    )
    return prediction, None


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def _check_kriging_memory(n_points, grid_pixels, is_dask=False):
    """Raise MemoryError if kriging() would exceed available memory.

    Three allocations dominate kriging memory use:

    * Experimental variogram pair arrays.  ``np.triu_indices`` produces
      two int64 index arrays of length ``N*(N-1)/2``, plus float64
      ``dists`` and ``semivar`` of the same length.  Roughly
      ``4 * N*(N-1)/2 * 8`` bytes.
    * Kriging matrix.  ``K`` and ``K_inv`` are both ``(N+1, N+1)``
      float64, plus an N x N intermediate distance matrix.  About
      ``3 * (N+1)**2 * 8`` bytes.
    * Prediction ``k0`` matrix of shape ``(grid_pixels, N+1)`` float64,
      plus matching ``dists`` and ``w`` of similar size.  About
      ``3 * grid_pixels * (N+1) * 8`` bytes.

    Worst case is the maximum of these three.  The variogram and matrix
    builds run sequentially, and ``k0`` is built later, so peak usage
    is bounded by the largest single allocation.

    When ``is_dask`` is True the prediction ``k0`` matrix is built one
    chunk at a time by ``map_blocks``, so its peak size scales with the
    chunk rather than ``grid_pixels``.  ``grid_pixels`` is not a valid
    bound for that path, so the ``k0`` term is dropped.  The variogram
    and matrix terms are point-based and materialised on the host
    regardless of backend, so they still apply.
    """
    n = int(n_points)
    g = int(grid_pixels)

    pair_bytes = 4 * (n * (n - 1) // 2) * 8 if n > 1 else 0
    matrix_bytes = 3 * (n + 1) * (n + 1) * 8
    k0_bytes = 0 if is_dask else 3 * g * (n + 1) * 8

    estimate = max(pair_bytes, matrix_bytes, k0_bytes)

    try:
        from xrspatial.zonal import _available_memory_bytes
        avail = _available_memory_bytes()
    except ImportError:
        avail = 2 * 1024 ** 3

    if estimate > 0.8 * avail:
        if estimate == k0_bytes:
            culprit = (
                f"prediction matrix of shape ({g}, {n + 1}) "
                f"(grid_pixels x N+1)"
            )
            advice = (
                "Reduce the template grid size or the number of input "
                "points, or use a chunked dask-backed template."
            )
        elif estimate == matrix_bytes:
            culprit = f"kriging matrix of shape ({n + 1}, {n + 1})"
            advice = "Reduce the number of input points."
        else:
            culprit = (
                f"variogram pair arrays of length {n * (n - 1) // 2} "
                f"(N*(N-1)/2 for N={n})"
            )
            advice = "Reduce the number of input points."

        raise MemoryError(
            f"kriging() needs ~{estimate / 1e9:.1f} GB to allocate the "
            f"{culprit}, but only ~{avail / 1e9:.1f} GB is available. "
            f"{advice}"
        )


def kriging(x, y=None, z=None, template=None, variogram_model='spherical',
            nlags=15, return_variance=False, name='kriging', *, column=None):
    """Ordinary Kriging interpolation.

    Parameters
    ----------
    x, y, z : array-like
        Coordinates and values of scattered input points.
        Alternatively, pass a GeoDataFrame of Point geometries as the
        first argument and leave *y*/*z* unset; *template* and *column*
        must then be keywords.
    template : xr.DataArray
        2-D DataArray whose grid defines the output raster.
    variogram_model : str, default 'spherical'
        Variogram model: ``'spherical'``, ``'exponential'``, or
        ``'gaussian'``.
    nlags : int, default 15
        Number of lag bins for the experimental variogram.
    return_variance : bool, default False
        If True, return ``(prediction, variance)`` tuple.
    name : str, default 'kriging'
        Name of the output DataArray.
    column : str, optional
        When the first argument is a GeoDataFrame, the column whose
        values are interpolated.  Defaults to the first numeric column.

    Returns
    -------
    xr.DataArray or tuple of xr.DataArray
        Prediction raster, or ``(prediction, variance)`` if
        *return_variance* is True.  The variance raster is named
        ``f'{name}_variance'``.  If the kriging matrix cannot be
        inverted, a warning is issued and the returned raster(s) are
        all-NaN.

    Raises
    ------
    ValueError
        If *variogram_model* is not one of the supported models.
    MemoryError
        If the worst-case allocation (variogram pair arrays, kriging
        matrix, or prediction matrix) would exceed 80% of available
        memory.
    """
    x, y, z = resolve_xyz(x, y, z, column=column, func_name='kriging')
    _validate_raster(template, func_name='kriging', name='template')
    x_arr, y_arr, z_arr = validate_points(x, y, z, func_name='kriging')

    if variogram_model not in _VARIOGRAM_MODELS:
        raise ValueError(
            f"kriging(): variogram_model must be one of "
            f"{sorted(_VARIOGRAM_MODELS)}, got {variogram_model!r}"
        )
    _validate_scalar(nlags, func_name='kriging', name='nlags',
                     dtype=int, min_val=1)

    # Memory guard.  Runs after input validation so we know N and the
    # template grid size, but before any large allocation.
    grid_pixels = int(np.prod(template.shape))
    is_dask = da is not None and isinstance(template.data, da.Array)
    _check_kriging_memory(len(x_arr), grid_pixels, is_dask=is_dask)

    # Experimental variogram
    lag_h, lag_sv = _experimental_variogram(x_arr, y_arr, z_arr, nlags)
    if len(lag_h) < 3:
        # stacklevel=2 attributes the warning to the caller of kriging().
        warnings.warn(
            "kriging(): fewer than 3 non-empty lag bins; variogram fit "
            "may be unreliable. Consider using more points or fewer lags.",
            stacklevel=2,
        )
    if len(lag_h) == 0:
        # No usable pairs (fewer than two distinct points): fit against a
        # single placeholder lag so a variogram can still be built.
        lag_h = np.array([1.0])
        lag_sv = np.array([1.0])

    # Fit variogram model
    model_func = _VARIOGRAM_MODELS[variogram_model]
    params = _fit_variogram(lag_h, lag_sv, model_func)

    def vario_func(h):
        return model_func(h, *params)

    # Build kriging matrix
    K_inv = _build_kriging_matrix(x_arr, y_arr, vario_func)

    x_grid, y_grid = extract_grid_coords(template, func_name='kriging')

    if K_inv is None:
        # Build the all-NaN fallback on the same backend as the template:
        # lazy per-chunk for dask templates, and np.full_like dispatches to
        # cupy via __array_function__ for cupy-backed blocks.
        if is_dask:
            pred = template.data.map_blocks(
                lambda block: np.full_like(block, np.nan, dtype=np.float64),
                dtype=np.float64,
            )
        else:
            pred = np.full_like(template.data, np.nan, dtype=np.float64)
        prediction = xr.DataArray(
            pred, name=name, coords=template.coords,
            dims=template.dims, attrs=template.attrs,
        )
        if return_variance:
            variance = prediction.copy()
            variance.name = f'{name}_variance'
            return prediction, variance
        return prediction

    mapper = ArrayTypeFunctionMapping(
        numpy_func=_kriging_numpy,
        cupy_func=_kriging_cupy,
        dask_func=_kriging_dask_numpy,
        dask_cupy_func=_kriging_dask_cupy,
    )
    pred_arr, var_arr = mapper(template)(
        x_arr, y_arr, z_arr, x_grid, y_grid,
        vario_func, K_inv, return_variance, template.data,
    )

    prediction = xr.DataArray(
        pred_arr, name=name, coords=template.coords,
        dims=template.dims, attrs=template.attrs,
    )
    if return_variance:
        variance = xr.DataArray(
            var_arr, name=f'{name}_variance', coords=template.coords,
            dims=template.dims, attrs=template.attrs,
        )
        return prediction, variance
    return prediction