# Source code for xrspatial.interpolate._spline

"""Thin Plate Spline (TPS) interpolation."""

from __future__ import annotations

import math

import numpy as np
import xarray as xr
from numba import cuda

from xrspatial.utils import (
    ArrayTypeFunctionMapping,
    _validate_raster,
    _validate_scalar,
    cuda_args,
    ngjit,
)

from ._validation import extract_grid_coords, validate_points

try:
    import cupy
except ImportError:
    cupy = None

try:
    import dask.array as da
except ImportError:
    da = None


# ---------------------------------------------------------------------------
# TPS system build (always CPU)
# ---------------------------------------------------------------------------

def _tps_radial(r2):
    """U(r) = r^2 ln(r) = 0.5 * r^2 * ln(r^2) for r > 0, else 0."""
    with np.errstate(divide='ignore', invalid='ignore'):
        result = np.where(r2 > 0, 0.5 * r2 * np.log(r2), 0.0)
    return result


def _tps_build_and_solve(x_pts, y_pts, z_pts, smoothing):
    """Build and solve the (N+3)x(N+3) TPS system.

    Returns the weight vector of length N+3:
        [w_0, ..., w_{N-1}, a_0, a_1, a_2]
    """
    n = len(x_pts)

    # With fewer than 3 points the full affine+radial system is
    # underdetermined.  Fall back to a simple constant/affine fit.
    if n < 3:
        weights = np.zeros(n + 3, dtype=np.float64)
        if n == 1:
            weights[n] = z_pts[0]  # constant surface
        else:
            # n == 2: fit z = a0 + a1*x + a2*y via least-squares
            P = np.ones((n, 3), dtype=np.float64)
            P[:, 1] = x_pts
            P[:, 2] = y_pts
            a, _, _, _ = np.linalg.lstsq(P, z_pts, rcond=None)
            weights[n:] = a
        return weights

    # K block: K[i,j] = U(||p_i - p_j||)
    dx = x_pts[:, np.newaxis] - x_pts[np.newaxis, :]
    dy = y_pts[:, np.newaxis] - y_pts[np.newaxis, :]
    r2 = dx ** 2 + dy ** 2
    K = _tps_radial(r2)

    # Regularisation
    K += (smoothing + 1e-10) * np.eye(n)

    # P block: [1, x_i, y_i]
    P = np.ones((n, 3), dtype=np.float64)
    P[:, 1] = x_pts
    P[:, 2] = y_pts

    # Assemble full system
    A = np.zeros((n + 3, n + 3), dtype=np.float64)
    A[:n, :n] = K
    A[:n, n:] = P
    A[n:, :n] = P.T

    b = np.zeros(n + 3, dtype=np.float64)
    b[:n] = z_pts

    try:
        weights = np.linalg.solve(A, b)
    except np.linalg.LinAlgError:
        # Degenerate geometry (e.g. collinear points)
        weights, _, _, _ = np.linalg.lstsq(A, b, rcond=None)
    return weights


# ---------------------------------------------------------------------------
# CPU evaluation kernel (numba JIT)
# ---------------------------------------------------------------------------

@ngjit
def _tps_evaluate_cpu(x_pts, y_pts, weights, n_pts, x_grid, y_grid):
    """Evaluate a solved TPS surface on a rectilinear grid (numba, CPU).

    Parameters
    ----------
    x_pts, y_pts : 1-D arrays
        Control-point coordinates; the first ``n_pts`` entries are used.
    weights : 1-D array
        ``[w_0, ..., w_{n_pts-1}, a_0, a_1, a_2]``.
    n_pts : int
        Number of control points.
    x_grid, y_grid : 1-D arrays
        Column (x) and row (y) coordinates of the output grid.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``(len(y_grid), len(x_grid))``.
    """
    a0 = weights[n_pts]
    a1 = weights[n_pts + 1]
    a2 = weights[n_pts + 2]
    n_rows = y_grid.shape[0]
    n_cols = x_grid.shape[0]
    result = np.empty((n_rows, n_cols), dtype=np.float64)

    for row in range(n_rows):
        py = y_grid[row]
        for col in range(n_cols):
            px = x_grid[col]
            # Affine part, then the radial terms in point order.
            acc = a0 + a1 * px + a2 * py
            for k in range(n_pts):
                ddx = px - x_pts[k]
                ddy = py - y_pts[k]
                d2 = ddx * ddx + ddy * ddy
                # U(r) = 0.5 * r^2 * ln(r^2); U(0) = 0 contributes nothing.
                if d2 > 0.0:
                    acc += weights[k] * d2 * math.log(d2) * 0.5
            result[row, col] = acc

    return result


# ---------------------------------------------------------------------------
# Numpy backend
# ---------------------------------------------------------------------------

def _spline_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
                  smoothing, weights, template_data):
    """Numpy backend: evaluate the pre-solved TPS over the whole grid.

    ``z_pts``, ``smoothing`` and ``template_data`` are unused here; they
    are accepted so every backend shares one call signature.
    """
    return _tps_evaluate_cpu(
        x_pts, y_pts, weights, len(x_pts), x_grid, y_grid,
    )


# ---------------------------------------------------------------------------
# CUDA evaluation kernel
# ---------------------------------------------------------------------------

@cuda.jit
def _tps_cuda_kernel(x_pts, y_pts, weights, n_pts, x_grid, y_grid, out):
    """CUDA kernel: each thread writes the TPS value of one (row, col) pixel.

    ``weights`` is laid out as ``[w_0, ..., w_{n_pts-1}, a_0, a_1, a_2]``;
    threads outside ``out`` exit immediately.
    """
    row, col = cuda.grid(2)
    if row >= out.shape[0] or col >= out.shape[1]:
        return

    px = x_grid[col]
    py = y_grid[row]
    acc = (weights[n_pts]
           + weights[n_pts + 1] * px
           + weights[n_pts + 2] * py)

    for k in range(n_pts):
        ddx = px - x_pts[k]
        ddy = py - y_pts[k]
        d2 = ddx * ddx + ddy * ddy
        # U(r) = 0.5 * r^2 * ln(r^2); U(0) = 0 contributes nothing.
        if d2 > 0.0:
            acc += weights[k] * d2 * math.log(d2) * 0.5

    out[row, col] = acc


# ---------------------------------------------------------------------------
# CuPy backend (CPU solve + GPU evaluate)
# ---------------------------------------------------------------------------

def _spline_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
                 smoothing, weights, template_data):
    """CuPy backend: upload points, weights and grid, then run the kernel.

    The weights were solved on CPU.  ``z_pts``, ``smoothing`` and
    ``template_data`` are unused and exist only for signature parity.

    Returns
    -------
    cupy.ndarray
        float64 array of shape ``(len(y_grid), len(x_grid))``.
    """
    d_x, d_y, d_w, d_xg, d_yg = (
        cupy.asarray(arr)
        for arr in (x_pts, y_pts, weights, x_grid, y_grid)
    )

    n_rows, n_cols = len(y_grid), len(x_grid)
    result = cupy.empty((n_rows, n_cols), dtype=np.float64)

    blocks, threads = cuda_args((n_rows, n_cols))
    _tps_cuda_kernel[blocks, threads](
        d_x, d_y, d_w, len(x_pts), d_xg, d_yg, result,
    )
    return result


# ---------------------------------------------------------------------------
# Dask + numpy backend
# ---------------------------------------------------------------------------

def _spline_dask_numpy(x_pts, y_pts, z_pts, x_grid, y_grid,
                       smoothing, weights, template_data):
    """Dask+numpy backend: evaluate the pre-solved TPS chunk by chunk.

    ``template_data`` supplies only the chunk layout.  Each chunk slices
    the grid coordinates to its own window, taken from
    ``block_info[0]['array-location']``, and evaluates the surface there
    with the numpy backend.

    ``meta`` is given explicitly, as in the dask+cupy backend, so dask
    knows the output is a float64 numpy array without probing ``_chunk``.
    """

    def _chunk(block, block_info=None):
        # Without block_info the chunk position is unknown; return the
        # input unchanged.
        if block_info is None:
            return block
        loc = block_info[0]['array-location']
        y_sl = y_grid[loc[0][0]:loc[0][1]]
        x_sl = x_grid[loc[1][0]:loc[1][1]]
        return _spline_numpy(x_pts, y_pts, z_pts, x_sl, y_sl,
                             smoothing, weights, None)

    return da.map_blocks(
        _chunk, template_data, dtype=np.float64,
        meta=np.array((), dtype=np.float64),
    )


# ---------------------------------------------------------------------------
# Dask + cupy backend
# ---------------------------------------------------------------------------

def _spline_dask_cupy(x_pts, y_pts, z_pts, x_grid, y_grid,
                      smoothing, weights, template_data):
    """Dask+cupy backend: evaluate the pre-solved TPS on GPU chunk by chunk.

    ``template_data`` supplies only the chunk layout.  Each chunk slices
    the grid coordinates to its window and runs the CuPy backend on it.
    ``meta`` marks the result as a float64 cupy array.
    """

    def _chunk(block, block_info=None):
        # Without block_info the chunk position is unknown; return the
        # input unchanged.
        if block_info is None:
            return block
        (row_lo, row_hi), (col_lo, col_hi) = (
            block_info[0]['array-location'][0],
            block_info[0]['array-location'][1],
        )
        return _spline_cupy(
            x_pts, y_pts, z_pts,
            x_grid[col_lo:col_hi], y_grid[row_lo:row_hi],
            smoothing, weights, None,
        )

    gpu_meta = cupy.array((), dtype=np.float64)
    return da.map_blocks(_chunk, template_data, dtype=np.float64,
                         meta=gpu_meta)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def _check_spline_memory(n_points, grid_pixels):
    """Raise MemoryError if spline() would exceed available memory.

    Two allocations dominate TPS memory use:

    * Kernel build.  Lines 67-70 construct ``dx``, ``dy``, ``r2`` and
      ``K`` each as N x N float64.  About ``4 * N**2 * 8`` bytes during
      construction.
    * Augmented system.  Line 81 builds ``A`` of shape ``(N+3, N+3)``
      float64, and ``np.linalg.solve`` copies it during LU
      factorization.  About ``2 * (N+3)**2 * 8`` bytes peak.

    The kernel build runs first and is freed before the solve, so peak
    usage is bounded by the larger of the two.  The grid is iterated
    point-by-point inside a numba kernel, so prediction does not
    materialize a grid x N matrix.
    """
    n = int(n_points)

    if n < 3:
        return  # short-circuit path, no big allocations

    kernel_bytes = 4 * n * n * 8
    solve_bytes = 2 * (n + 3) * (n + 3) * 8

    estimate = max(kernel_bytes, solve_bytes)

    try:
        from xrspatial.zonal import _available_memory_bytes
        avail = _available_memory_bytes()
    except ImportError:
        avail = 2 * 1024 ** 3

    # Match the kriging guard's 0.8 fraction (PR #1309) so users get
    # consistent behaviour across interpolators.  The TPS solve and the
    # kernel build do not run concurrently with other large arrays in
    # this function, so 80% of available RAM is the right ceiling.
    if estimate > 0.8 * avail:
        if estimate == solve_bytes:
            culprit = (
                f"augmented system A of shape ({n + 3}, {n + 3}) "
                f"plus its LU factorization copy"
            )
        else:
            culprit = (
                f"radial-basis kernel block K of shape ({n}, {n}) "
                f"(plus dx, dy, r2 of the same shape)"
            )

        raise MemoryError(
            f"spline() needs ~{estimate / 1e9:.1f} GB to allocate the "
            f"{culprit}, but only ~{avail / 1e9:.1f} GB is available. "
            f"Reduce the number of input points."
        )


def spline(x, y, z, template, smoothing=0.0, name='spline'):
    """Thin Plate Spline interpolation.

    Fits a thin plate spline through scattered points (solved once on
    CPU) and evaluates it on the grid of *template*.  Evaluation is
    dispatched through ``ArrayTypeFunctionMapping`` to a numpy, cupy,
    dask+numpy or dask+cupy evaluator.

    Parameters
    ----------
    x, y, z : array-like
        Coordinates and values of scattered input points.
    template : xr.DataArray
        2-D DataArray whose grid defines the output raster.
    smoothing : float, default 0.0
        Smoothing parameter (must be >= 0).  0 forces exact interpolation
        through the data points, up to a 1e-10 stabilising ridge.
    name : str, default 'spline'
        Name of the output DataArray.

    Returns
    -------
    xr.DataArray
        Interpolated surface carrying the coords, dims and attrs of
        *template*.

    Raises
    ------
    MemoryError
        If the worst-case allocation (kernel block K or augmented system
        A) would exceed 80% of available memory.
    """
    _validate_raster(template, func_name='spline', name='template')
    x_arr, y_arr, z_arr = validate_points(x, y, z, func_name='spline')
    _validate_scalar(smoothing, func_name='spline', name='smoothing',
                     min_val=0.0)
    x_grid, y_grid = extract_grid_coords(template, func_name='spline')

    # Fail fast once N is known, before the O(N^2) system is built.
    _check_spline_memory(len(x_arr), int(np.prod(template.shape)))

    # One CPU solve; every backend only evaluates the resulting surface.
    weights = _tps_build_and_solve(x_arr, y_arr, z_arr, smoothing)

    evaluate = ArrayTypeFunctionMapping(
        numpy_func=_spline_numpy,
        cupy_func=_spline_cupy,
        dask_func=_spline_dask_numpy,
        dask_cupy_func=_spline_dask_cupy,
    )(template)
    result = evaluate(
        x_arr, y_arr, z_arr, x_grid, y_grid,
        smoothing, weights, template.data,
    )

    return xr.DataArray(
        result,
        name=name,
        coords=template.coords,
        dims=template.dims,
        attrs=template.attrs,
    )