"""Ordinary Kriging interpolation."""
from __future__ import annotations
import warnings
import numpy as np
import xarray as xr
from xrspatial.utils import (
ArrayTypeFunctionMapping,
_validate_raster,
_validate_scalar,
)
from ._validation import extract_grid_coords, validate_points
try:
import cupy
except ImportError:
cupy = None
try:
import dask.array as da
except ImportError:
da = None
def _get_xp(arr):
    """Pick the array namespace matching *arr*.

    Returns the ``cupy`` module when CuPy was imported successfully and
    *arr* is a ``cupy.ndarray``; returns ``numpy`` for every other input.
    """
    is_cupy_array = cupy is not None and isinstance(arr, cupy.ndarray)
    return cupy if is_cupy_array else np
# ---------------------------------------------------------------------------
# Variogram models
# ---------------------------------------------------------------------------
def _spherical(h, c0, c, a):
    """Evaluate the spherical variogram model at lag(s) *h*.

    For ``h < a`` the value is ``c0 + c * (1.5*h/a - 0.5*(h/a)**3)``;
    for ``h >= a`` it is the constant plateau ``c0 + c``.  The array
    module (numpy or cupy) is chosen from the type of *h*.
    """
    xp = _get_xp(h)
    inside_range = h < a
    rising = c0 + c * (1.5 * h / a - 0.5 * (h / a) ** 3)
    plateau = c0 + c
    return xp.where(inside_range, rising, plateau)
def _exponential(h, c0, c, a):
    """Evaluate the exponential variogram model at lag(s) *h*.

    Computes ``c0 + c * (1 - exp(-3*h/a))`` with the array module
    (numpy or cupy) that matches *h*.
    """
    xp = _get_xp(h)
    decay = xp.exp(-3.0 * h / a)
    return c0 + c * (1.0 - decay)
def _gaussian(h, c0, c, a):
    """Evaluate the Gaussian variogram model at lag(s) *h*.

    Computes ``c0 + c * (1 - exp(-3*(h/a)**2))`` with the array module
    (numpy or cupy) that matches *h*.
    """
    xp = _get_xp(h)
    decay = xp.exp(-3.0 * (h / a) ** 2)
    return c0 + c * (1.0 - decay)
# Lookup table from the ``variogram_model`` name accepted by ``kriging``
# to the function implementing that model.
_VARIOGRAM_MODELS = dict(
    spherical=_spherical,
    exponential=_exponential,
    gaussian=_gaussian,
)
# ---------------------------------------------------------------------------
# Experimental variogram
# ---------------------------------------------------------------------------
def _experimental_variogram(x, y, z, nlags):
"""Compute binned experimental variogram from point data."""
n = len(x)
i_idx, j_idx = np.triu_indices(n, k=1)
dx = x[i_idx] - x[j_idx]
dy = y[i_idx] - y[j_idx]
dists = np.sqrt(dx ** 2 + dy ** 2)
semivar = 0.5 * (z[i_idx] - z[j_idx]) ** 2
max_dist = dists.max() / 2.0
if max_dist <= 0:
return np.array([]), np.array([])
bins = np.linspace(0, max_dist, nlags + 1)
lag_centers = 0.5 * (bins[:-1] + bins[1:])
lag_sv = np.zeros(nlags, dtype=np.float64)
lag_count = np.zeros(nlags, dtype=np.int64)
bin_idx = np.digitize(dists, bins) - 1
for k in range(nlags):
mask = bin_idx == k
if mask.any():
lag_sv[k] = semivar[mask].mean()
lag_count[k] = mask.sum()
valid = lag_count > 0
return lag_centers[valid], lag_sv[valid]
# ---------------------------------------------------------------------------
# Variogram fitting
# ---------------------------------------------------------------------------
def _fit_variogram(lag_h, lag_sv, model_func):
"""Fit variogram model parameters via curve_fit.
Returns (c0, c, a).
"""
from scipy.optimize import curve_fit
c0_init = float(lag_sv[0]) if len(lag_sv) > 0 else 0.0
c_init = float(lag_sv.max() - c0_init) if len(lag_sv) > 0 else 1.0
a_init = float(lag_h[-1]) if len(lag_h) > 0 else 1.0
c_init = max(c_init, 1e-10)
a_init = max(a_init, 1e-10)
try:
popt, _ = curve_fit(
model_func, lag_h, lag_sv,
p0=[c0_init, c_init, a_init],
bounds=([0, 0, 1e-12], [np.inf, np.inf, np.inf]),
maxfev=5000,
)
except (RuntimeError, ValueError):
popt = np.array([c0_init, c_init, a_init])
return popt
# ---------------------------------------------------------------------------
# Kriging matrix
# ---------------------------------------------------------------------------
def _build_kriging_matrix(x, y, vario_func):
"""Build the (N+1)x(N+1) ordinary-kriging matrix and return K_inv."""
n = len(x)
dx = x[:, np.newaxis] - x[np.newaxis, :]
dy = y[:, np.newaxis] - y[np.newaxis, :]
D = np.sqrt(dx ** 2 + dy ** 2)
K = np.zeros((n + 1, n + 1), dtype=np.float64)
K[:n, :n] = vario_func(D)
K[:n, n] = 1.0
K[n, :n] = 1.0
try:
K_inv = np.linalg.inv(K)
except np.linalg.LinAlgError:
K[:n, :n] += 1e-10 * np.eye(n)
try:
K_inv = np.linalg.inv(K)
except np.linalg.LinAlgError:
warnings.warn(
"kriging(): kriging matrix is singular even after "
"regularisation; predictions will be NaN."
)
return None
return K_inv
# ---------------------------------------------------------------------------
# Numpy prediction
# ---------------------------------------------------------------------------
def _kriging_predict(x_pts, y_pts, z_pts, x_grid, y_grid,
vario_func, K_inv, return_variance):
"""Vectorised kriging prediction for a grid chunk."""
n = len(x_pts)
gx, gy = np.meshgrid(x_grid, y_grid)
gx_flat = gx.ravel()
gy_flat = gy.ravel()
dx = gx_flat[:, np.newaxis] - x_pts[np.newaxis, :]
dy = gy_flat[:, np.newaxis] - y_pts[np.newaxis, :]
dists = np.sqrt(dx ** 2 + dy ** 2)
k0 = np.empty((len(gx_flat), n + 1), dtype=np.float64)
k0[:, :n] = vario_func(dists)
k0[:, n] = 1.0
# w = K_inv @ k0 for each pixel (K_inv is symmetric)
w = k0 @ K_inv
prediction = (w[:, :n] * z_pts[np.newaxis, :]).sum(axis=1)
prediction = prediction.reshape(len(y_grid), len(x_grid))
variance = None
if return_variance:
variance = np.sum(w * k0, axis=1)
variance = variance.reshape(len(y_grid), len(x_grid))
return prediction, variance
# ---------------------------------------------------------------------------
# Numpy backend wrapper
# ---------------------------------------------------------------------------
def _kriging_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
                   vario_func, K_inv, return_variance, template_data):
    """NumPy backend: predict the whole grid in a single call.

    *template_data* is accepted so that all backends share one call
    signature; it is not used here.  Returns ``(prediction, variance)``
    exactly as ``_kriging_predict`` does.
    """
    result = _kriging_predict(
        x_pts, y_pts, z_pts, x_grid, y_grid,
        vario_func, K_inv, return_variance,
    )
    return result
# ---------------------------------------------------------------------------
# Dask + numpy backend
# ---------------------------------------------------------------------------
def _kriging_dask_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
                        vario_func, K_inv, return_variance, template_data):
    """Dask + NumPy backend: predict lazily, one template chunk at a time.

    Each block of *template_data* is replaced by the kriging result for
    the matching slice of ``y_grid`` / ``x_grid``; the block's own values
    are not read.  Returns ``(prediction, variance)`` as dask arrays,
    with *variance* ``None`` unless *return_variance* is true.
    """
    def _solve_block(block_info, want_variance):
        # 'array-location' gives the (start, stop) span of this block
        # along each axis of the full array.
        loc = block_info[0]['array-location']
        y_start, y_stop = loc[0]
        x_start, x_stop = loc[1]
        return _kriging_predict(
            x_pts, y_pts, z_pts,
            x_grid[x_start:x_stop], y_grid[y_start:y_stop],
            vario_func, K_inv, want_variance,
        )

    def _chunk_pred(block, block_info=None):
        # Without block_info there is no location to predict for
        # (presumably dask's metadata-inference call); pass through.
        if block_info is None:
            return block
        return _solve_block(block_info, False)[0]

    def _chunk_var(block, block_info=None):
        if block_info is None:
            return block
        return _solve_block(block_info, True)[1]

    prediction = da.map_blocks(_chunk_pred, template_data, dtype=np.float64)
    if not return_variance:
        return prediction, None

    variance = da.map_blocks(_chunk_var, template_data, dtype=np.float64)
    return prediction, variance
# ---------------------------------------------------------------------------
# CuPy prediction
# ---------------------------------------------------------------------------
def _kriging_predict_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
                          vario_func, K_inv, return_variance):
    """Vectorised kriging prediction on GPU via CuPy.

    Same computation as the NumPy predictor, but every input is first
    converted with ``cupy.asarray`` and all array operations use CuPy.
    Returns ``(prediction, variance)`` as CuPy arrays of shape
    ``(len(y_grid), len(x_grid))``; *variance* is ``None`` unless
    *return_variance* is true.
    """
    n_pts = len(x_pts)
    out_shape = (len(y_grid), len(x_grid))

    # Convert all inputs to CuPy arrays.
    px_src = cupy.asarray(x_pts)
    py_src = cupy.asarray(y_pts)
    z_src = cupy.asarray(z_pts)
    xg_dev = cupy.asarray(x_grid)
    yg_dev = cupy.asarray(y_grid)
    K_inv_dev = cupy.asarray(K_inv)

    gx, gy = cupy.meshgrid(xg_dev, yg_dev)
    px = gx.ravel()
    py = gy.ravel()

    # Distance from every grid pixel (rows) to every input point (columns).
    dists = cupy.sqrt(
        (px[:, None] - px_src[None, :]) ** 2
        + (py[:, None] - py_src[None, :]) ** 2
    )

    k0 = cupy.empty((len(px), n_pts + 1), dtype=np.float64)
    k0[:, :n_pts] = vario_func(dists)
    k0[:, n_pts] = 1.0

    weights = k0 @ K_inv_dev

    prediction = (weights[:, :n_pts] * z_src[None, :]).sum(axis=1)
    prediction = prediction.reshape(out_shape)

    if not return_variance:
        return prediction, None

    variance = cupy.sum(weights * k0, axis=1).reshape(out_shape)
    return prediction, variance
# ---------------------------------------------------------------------------
# CuPy backend wrapper
# ---------------------------------------------------------------------------
def _kriging_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
                  vario_func, K_inv, return_variance, template_data):
    """CuPy backend: predict the whole grid in a single call.

    *template_data* is accepted so that all backends share one call
    signature; it is not used here.  Returns ``(prediction, variance)``
    exactly as ``_kriging_predict_cupy`` does.
    """
    result = _kriging_predict_cupy(
        x_pts, y_pts, z_pts, x_grid, y_grid,
        vario_func, K_inv, return_variance,
    )
    return result
# ---------------------------------------------------------------------------
# Dask + CuPy backend
# ---------------------------------------------------------------------------
def _kriging_dask_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
                       vario_func, K_inv, return_variance, template_data):
    """Dask + CuPy backend: predict lazily, one template chunk at a time.

    Each block of *template_data* is replaced by the CuPy kriging result
    for the matching slice of ``y_grid`` / ``x_grid``; the block's own
    values are not read.  An empty CuPy array is passed as ``meta`` to
    ``map_blocks``.  Returns ``(prediction, variance)`` as dask arrays,
    with *variance* ``None`` unless *return_variance* is true.
    """
    def _solve_block(block_info, want_variance):
        # 'array-location' gives the (start, stop) span of this block
        # along each axis of the full array.
        loc = block_info[0]['array-location']
        y_start, y_stop = loc[0]
        x_start, x_stop = loc[1]
        return _kriging_predict_cupy(
            x_pts, y_pts, z_pts,
            x_grid[x_start:x_stop], y_grid[y_start:y_stop],
            vario_func, K_inv, want_variance,
        )

    def _chunk_pred(block, block_info=None):
        # Without block_info there is no location to predict for; pass
        # the block through unchanged.
        if block_info is None:
            return block
        return _solve_block(block_info, False)[0]

    def _chunk_var(block, block_info=None):
        if block_info is None:
            return block
        return _solve_block(block_info, True)[1]

    prediction = da.map_blocks(
        _chunk_pred, template_data, dtype=np.float64,
        meta=cupy.array((), dtype=np.float64),
    )
    if not return_variance:
        return prediction, None

    variance = da.map_blocks(
        _chunk_var, template_data, dtype=np.float64,
        meta=cupy.array((), dtype=np.float64),
    )
    return prediction, variance
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _check_kriging_memory(n_points, grid_pixels):
"""Raise MemoryError if kriging() would exceed available memory.
Three allocations dominate kriging memory use:
* Experimental variogram pair arrays. ``np.triu_indices`` produces
two int64 index arrays of length ``N*(N-1)/2``, plus float64
``dists`` and ``semivar`` of the same length. Roughly
``4 * N*(N-1)/2 * 8`` bytes.
* Kriging matrix. ``K`` and ``K_inv`` are both ``(N+1, N+1)``
float64, plus an N x N intermediate distance matrix. About
``3 * (N+1)**2 * 8`` bytes.
* Prediction ``k0`` matrix of shape ``(grid_pixels, N+1)`` float64,
plus matching ``dists`` and ``w`` of similar size. About
``3 * grid_pixels * (N+1) * 8`` bytes.
Worst case is the maximum of these three. The variogram and matrix
builds run sequentially, and ``k0`` is built later, so peak usage
is bounded by the largest single allocation.
"""
n = int(n_points)
g = int(grid_pixels)
pair_bytes = 4 * (n * (n - 1) // 2) * 8 if n > 1 else 0
matrix_bytes = 3 * (n + 1) * (n + 1) * 8
k0_bytes = 3 * g * (n + 1) * 8
estimate = max(pair_bytes, matrix_bytes, k0_bytes)
try:
from xrspatial.zonal import _available_memory_bytes
avail = _available_memory_bytes()
except ImportError:
avail = 2 * 1024 ** 3
if estimate > 0.8 * avail:
if estimate == k0_bytes:
culprit = (
f"prediction matrix of shape ({g}, {n + 1}) "
f"(grid_pixels x N+1)"
)
advice = (
"Reduce the template grid size or the number of input "
"points, or use a chunked dask-backed template."
)
elif estimate == matrix_bytes:
culprit = f"kriging matrix of shape ({n + 1}, {n + 1})"
advice = "Reduce the number of input points."
else:
culprit = (
f"variogram pair arrays of length {n * (n - 1) // 2} "
f"(N*(N-1)/2 for N={n})"
)
advice = "Reduce the number of input points."
raise MemoryError(
f"kriging() needs ~{estimate / 1e9:.1f} GB to allocate the "
f"{culprit}, but only ~{avail / 1e9:.1f} GB is available. "
f"{advice}"
)
# Public entry point.
def kriging(x, y, z, template, variogram_model='spherical', nlags=15,
            return_variance=False, name='kriging'):
    """Ordinary Kriging interpolation.

    Parameters
    ----------
    x, y, z : array-like
        Coordinates and values of scattered input points.
    template : xr.DataArray
        2-D DataArray whose grid defines the output raster.
    variogram_model : str, default 'spherical'
        Variogram model: ``'spherical'``, ``'exponential'``, or
        ``'gaussian'``.
    nlags : int, default 15
        Number of lag bins for the experimental variogram.
    return_variance : bool, default False
        If True, return ``(prediction, variance)`` tuple.
    name : str, default 'kriging'
        Name of the output DataArray.  The variance raster is named
        ``f'{name}_variance'``.

    Returns
    -------
    xr.DataArray or tuple of xr.DataArray
        Prediction raster, or ``(prediction, variance)`` if
        *return_variance* is True.  If the kriging matrix cannot be
        inverted, the returned rasters are filled with NaN.

    Raises
    ------
    ValueError
        If *variogram_model* is not one of the supported names.
    MemoryError
        If the worst-case allocation (variogram pair arrays, kriging
        matrix, or prediction matrix) would exceed 80% of available
        memory.  For a chunked (dask-backed) template the prediction
        matrix is sized from the largest chunk, because prediction is
        carried out one chunk at a time.
    """
    _validate_raster(template, func_name='kriging', name='template')
    x_arr, y_arr, z_arr = validate_points(x, y, z, func_name='kriging')

    if variogram_model not in _VARIOGRAM_MODELS:
        raise ValueError(
            f"kriging(): variogram_model must be one of "
            f"{sorted(_VARIOGRAM_MODELS)}, got {variogram_model!r}"
        )

    _validate_scalar(nlags, func_name='kriging', name='nlags',
                     dtype=int, min_val=1)

    # Memory guard. Runs after input validation so we know N and the
    # template grid size, but before any large allocation.
    #
    # The dask backends build the prediction matrix per chunk, so for a
    # chunked template the relevant pixel count is that of the largest
    # chunk, not the whole grid.  (Chunks evaluated concurrently can
    # still add up to more than this single-chunk estimate.)
    chunks = template.chunks
    if chunks:
        per_call_shape = [max(dim_chunks) for dim_chunks in chunks]
    else:
        per_call_shape = template.shape
    grid_pixels = int(np.prod(per_call_shape, dtype=np.int64))
    _check_kriging_memory(len(x_arr), grid_pixels)

    # Experimental variogram
    lag_h, lag_sv = _experimental_variogram(x_arr, y_arr, z_arr, nlags)

    if len(lag_h) < 3:
        warnings.warn(
            "kriging(): fewer than 3 non-empty lag bins; variogram fit "
            "may be unreliable. Consider using more points or fewer lags."
        )
    if len(lag_h) == 0:
        # Nothing to fit against: use a single placeholder lag so the
        # fitting step still receives non-empty input.
        lag_h = np.array([1.0])
        lag_sv = np.array([1.0])

    # Fit variogram model
    model_func = _VARIOGRAM_MODELS[variogram_model]
    params = _fit_variogram(lag_h, lag_sv, model_func)

    def vario_func(h):
        return model_func(h, *params)

    # Build kriging matrix
    K_inv = _build_kriging_matrix(x_arr, y_arr, vario_func)

    x_grid, y_grid = extract_grid_coords(template, func_name='kriging')

    if K_inv is None:
        # Singular system: return all-NaN rasters on the template grid.
        pred = np.full(template.shape, np.nan)
        prediction = xr.DataArray(
            pred, name=name,
            coords=template.coords,
            dims=template.dims,
            attrs=template.attrs,
        )
        if return_variance:
            # Name the variance raster the same way as the normal path.
            variance = prediction.copy().rename(f'{name}_variance')
            return prediction, variance
        return prediction

    # Dispatch on the template's array backend.
    mapper = ArrayTypeFunctionMapping(
        numpy_func=_kriging_numpy,
        cupy_func=_kriging_cupy,
        dask_func=_kriging_dask_numpy,
        dask_cupy_func=_kriging_dask_cupy,
    )

    pred_arr, var_arr = mapper(template)(
        x_arr, y_arr, z_arr, x_grid, y_grid,
        vario_func, K_inv, return_variance, template.data,
    )

    prediction = xr.DataArray(
        pred_arr, name=name,
        coords=template.coords,
        dims=template.dims,
        attrs=template.attrs,
    )

    if return_variance:
        variance = xr.DataArray(
            var_arr, name=f'{name}_variance',
            coords=template.coords,
            dims=template.dims,
            attrs=template.attrs,
        )
        return prediction, variance

    return prediction