# Source code for xrspatial.resample

"""Raster resampling -- resolution change without reprojection.

Provides :func:`resample` for changing raster cell size using
interpolation or block-aggregation methods.
"""
from __future__ import annotations

from functools import partial

import numpy as np
import xarray as xr
from scipy.ndimage import map_coordinates as _scipy_map_coords
from scipy.ndimage import spline_filter as _scipy_spline_filter

try:
    import dask.array as da
except ImportError:
    da = None

try:
    import cupy
except ImportError:
    cupy = None

from xrspatial.dataset_support import supports_dataset
from xrspatial.utils import ArrayTypeFunctionMapping, _validate_raster, calc_res, ngjit

# -- Constants ---------------------------------------------------------------

INTERP_METHODS = {'nearest': 0, 'bilinear': 1, 'cubic': 3}
AGGREGATE_METHODS = {'average', 'min', 'max', 'median', 'mode'}
ALL_METHODS = set(INTERP_METHODS) | AGGREGATE_METHODS

# Overlap depth (input pixels) each interpolation kernel needs from
# neighbouring chunks when processing dask arrays.  Cubic requires extra
# depth because the B-spline prefilter is a global IIR filter whose
# boundary transient decays as ~0.268^n.  Depth 16 puts the residual at
# ~7e-10, comfortably below float32 epsilon so chunk-seam parity rounds
# to zero in the float32 output.  The dask drivers clamp this per axis
# down to ``axis_total - 1`` when the array is too small to absorb the
# full depth; see ``_run_dask_numpy`` for the rationale.
_INTERP_DEPTH = {'nearest': 1, 'bilinear': 1, 'cubic': 16}

# Approximate working-set size per output cell for the eager backends:
# one float64 working buffer (8 B) plus a float64 output cell (8 B) in
# the worst case. scipy.ndimage.map_coordinates also allocates a
# temporary of the same size during higher-order spline evaluation; the
# 0.5 * available bound below leaves room for that.
_BYTES_PER_OUTPUT_CELL = 16


# -- Working / output dtype selection ----------------------------------------

def _working_dtype(input_dtype):
    """Pick the working float dtype for resampling.

    float64 inputs stay in float64 to preserve precision; everything else
    (smaller floats, integers, bool) uses float32.
    """
    dt = np.dtype(input_dtype)
    if dt.kind == 'f' and dt.itemsize >= 8:
        return np.float64
    return np.float32


def _output_dtype(input_dtype):
    """Pick the output dtype for resampling.

    Float inputs keep their dtype. Integer / bool inputs return float32
    because NaN-sentinel resampling needs a float type.
    """
    dt = np.dtype(input_dtype)
    if dt.kind == 'f':
        return dt.type
    return np.float32


def _maybe_astype(arr, dtype):
    """astype copy that no-ops when already at the requested dtype."""
    return arr if arr.dtype == np.dtype(dtype) else arr.astype(dtype)


# -- Memory guard ------------------------------------------------------------

def _available_memory_bytes():
    """Best-effort estimate of available host memory in bytes."""
    # Try /proc/meminfo (Linux)
    try:
        with open('/proc/meminfo', 'r') as f:
            for line in f:
                if line.startswith('MemAvailable:'):
                    return int(line.split()[1]) * 1024
    except (OSError, ValueError, IndexError):
        pass
    # Try psutil
    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass
    # Fallback: 2 GB
    return 2 * 1024 ** 3


def _available_gpu_memory_bytes():
    """Best-effort estimate of free GPU memory in bytes.

    Returns 0 when CuPy / CUDA is unavailable or the query fails -- callers
    treat that as a sentinel meaning "no GPU info, skip the guard".
    """
    try:
        import cupy as _cp
        free, _total = _cp.cuda.runtime.memGetInfo()
        return int(free)
    except Exception:
        return 0


def _check_resample_memory(out_h, out_w):
    """Fail fast when an eager (numpy / cupy-eager) output would not fit in RAM.

    The estimate is ``out_h * out_w * _BYTES_PER_OUTPUT_CELL`` -- a float64
    working cell plus an output cell per pixel.  If that exceeds half of
    the memory reported by ``_available_memory_bytes()``, a ``MemoryError``
    is raised before any large allocation happens.  The check exists
    because an extreme ``scale_factor`` or ``target_resolution`` would
    otherwise OOM the process.

    Raises
    ------
    MemoryError
        If the estimated working set exceeds 50% of available host memory.
    """
    n_cells = int(out_h) * int(out_w)
    needed = n_cells * _BYTES_PER_OUTPUT_CELL
    free_bytes = _available_memory_bytes()
    if needed <= 0.5 * free_bytes:
        return
    raise MemoryError(
        f"resample output of {out_h}x{out_w} would need "
        f"~{needed / 1e9:.1f} GB of working memory but only "
        f"~{free_bytes / 1e9:.1f} GB is available. "
        f"Use a smaller scale_factor / larger target_resolution, "
        f"or pass a dask-backed DataArray for out-of-core processing."
    )


def _check_resample_gpu_memory(out_h, out_w):
    """Fail fast when a cupy-eager output would not fit in free VRAM.

    Uses the same per-cell estimate as the host-memory guard.  When free
    device memory cannot be determined (``_available_gpu_memory_bytes()``
    returns 0), the check is skipped; the allocation itself will fail later
    if memory really is short.

    Raises
    ------
    MemoryError
        If the estimated working set exceeds 50% of free GPU memory.
    """
    free_bytes = _available_gpu_memory_bytes()
    if free_bytes > 0:
        needed = int(out_h) * int(out_w) * _BYTES_PER_OUTPUT_CELL
        if needed > 0.5 * free_bytes:
            raise MemoryError(
                f"resample output of {out_h}x{out_w} would need "
                f"~{needed / 1e9:.1f} GB of GPU working memory but only "
                f"~{free_bytes / 1e9:.1f} GB is free on the active device. "
                f"Use a smaller scale_factor / larger target_resolution, "
                f"or pass a dask+cupy DataArray for out-of-core processing."
            )


# -- Input-validation helpers ------------------------------------------------

def _validate_resample_scalar_or_pair(value, param_name):
    """Validate a scalar-or-2-tuple resolution / scale parameter.

    Accepts either a real scalar or a length-2 tuple/list of scalars.
    Each component must be finite (not NaN, not inf) and strictly
    positive. Raises ``ValueError`` with a message naming the parameter
    and the offending value.
    """
    is_pair = isinstance(value, (tuple, list))
    if is_pair:
        if len(value) != 2:
            raise ValueError(
                f"{param_name} must have length 2, got length {len(value)}"
            )
        components = value
    else:
        components = (value,)

    for i, comp in enumerate(components):
        # Suffix points at the bad slot when the input was a pair, so
        # `(0.0, 1.0)` reports "got 0.0 at index 0 of (0.0, 1.0)"
        # instead of dumping the whole tuple.
        where = f"{comp!r} at index {i} of {value!r}" if is_pair else f"{value!r}"
        try:
            f = float(comp)
        except (TypeError, ValueError):
            raise ValueError(
                f"{param_name} must be a finite positive number "
                f"(or length-2 sequence of them), got {where}"
            ) from None
        if not np.isfinite(f):
            raise ValueError(
                f"{param_name} must be finite and > 0, got {where}"
            )
        if f <= 0:
            raise ValueError(
                f"{param_name} must be > 0, got {where}"
            )


def _validate_monotonic_regular_coords(agg):
    """Reject inputs whose spatial coords are not regular and monotonic.

    ``resample`` assumes a regular, monotonic grid: ``calc_res`` derives
    the input resolution from the full coordinate extent while the output
    coordinates are rebuilt from first/last neighbour spacing. On an
    irregular or non-monotonic grid those two views of "resolution"
    disagree and the function silently produces inconsistent output
    geometry (wrong width, coords spilling past the input range). Fail
    fast here instead.

    Only 1-D coords that actually exist on the spatial dims are checked;
    an input without spatial coords is left to the existing code paths.
    For 3-D inputs ``resample`` recurses per band, so this runs once per
    band on identical coords -- a cheap, harmless repeat.
    """
    for dim in agg.dims[-2:]:
        if dim not in agg.coords:
            continue
        vals = np.asarray(agg[dim].values, dtype=np.float64)
        if vals.ndim != 1 or vals.size < 2:
            continue
        diffs = np.diff(vals)
        if not (np.all(diffs > 0) or np.all(diffs < 0)):
            raise ValueError(
                f"resample(): `agg` coordinate {dim!r} must be strictly "
                f"monotonic (consistently increasing or decreasing); "
                f"resample only supports regular monotonic rasters"
            )
        # Allow floating-point jitter but reject genuinely uneven spacing
        # (e.g. [0, 1, 4]). Compare every step to the mean step. The
        # tolerance scales with the step size via ``rtol`` so it tracks
        # the coordinate magnitude.
        step = diffs.mean()
        if not np.allclose(diffs, step, rtol=1e-5, atol=0.0):
            raise ValueError(
                f"resample(): `agg` coordinate {dim!r} must be evenly "
                f"spaced; resample only supports regular monotonic "
                f"rasters, not irregular grids"
            )


# -- Output-geometry helpers -------------------------------------------------

def _output_shape(in_h, in_w, scale_y, scale_x):
    return max(1, round(in_h * scale_y)), max(1, round(in_w * scale_x))


def _output_chunks(in_chunks, scale):
    """Compute per-chunk output sizes via cumulative rounding.

    Guarantees ``sum(result) == round(sum(in_chunks) * scale)``.
    """
    cum = np.cumsum([0] + list(in_chunks))
    out_cum = np.round(cum * scale).astype(int)
    return tuple(int(max(1, out_cum[i + 1] - out_cum[i]))
                 for i in range(len(in_chunks)))


# -- Block-centered coordinate mapping ---------------------------------------

def _block_centered_coords(n_in, n_out):
    """Return input coordinates for each output pixel using block-centered mapping.

    Maps output pixel ``o`` to input pixel ``(o + 0.5) * (n_in / n_out) - 0.5``.
    This places each output pixel at the center of its spatial footprint,
    matching the convention used by ``_new_coords`` for output coordinate
    metadata.
    """
    o = np.arange(n_out, dtype=np.float64)
    return (o + 0.5) * (n_in / n_out) - 0.5


# -- Spline prefilter helpers -----------------------------------------------
#
# scipy.ndimage.map_coordinates(prefilter=True) silently does three things:
# (1) pad the input by 12 pixels (edge replication for mode='nearest',
# constant ``cval`` for 'grid-constant') so the IIR transient stabilises
# before reaching real data,
# (2) call spline_filter on that padded array, and
# (3) shift the sample coordinates by the same offset.  The padding step
# is private (``_prepad_for_spline_filter``) and is needed for the
# explicit-prefilter path to match the implicit one bit-for-bit.
#
# We replicate it here so callers can prefilter once per array (e.g. the
# NaN-aware filled / weights pair) and pass ``prefilter=False`` to
# map_coordinates without changing the boundary semantics.  Doing the
# prefilter explicitly also makes the per-block dask path deterministic --
# the same spline coefficients are computed in eager and chunked modes
# (modulo the IIR boundary transient at chunk seams, which the cubic
# overlap depth in ``_INTERP_DEPTH`` -- 16 pixels -- is sized to absorb).

_SPLINE_PREPAD_NEAREST = 12


def _prepad_and_filter_np(arr, order):
    """Edge-pad *arr* and compute its spline coefficients (NumPy).

    Replicates the implicit ``map_coordinates(prefilter=True,
    mode='nearest')`` preprocessing.  Returns ``(coefficients, pad)``;
    callers must add *pad* to their sample coordinates and then call
    ``map_coordinates`` with ``prefilter=False``.
    """
    pad = _SPLINE_PREPAD_NEAREST
    extended = np.pad(arr, pad, mode='edge')
    coefficients = _scipy_spline_filter(extended, order=order, mode='nearest')
    return coefficients, pad


def _prepad_and_filter_cupy(arr, order, spline_filter_fn):
    """CuPy counterpart of :func:`_prepad_and_filter_np`.

    *spline_filter_fn* is the caller-supplied spline filter (the callers in
    this module pass ``cupyx.scipy.ndimage.spline_filter``).  Returns
    ``(coefficients, pad)``.
    """
    pad = _SPLINE_PREPAD_NEAREST
    extended = cupy.pad(arr, pad, mode='edge')
    coefficients = spline_filter_fn(extended, order=order, mode='nearest')
    return coefficients, pad


# -- NaN-aware interpolation (NumPy) ----------------------------------------

def _nan_aware_interp_np(data, out_h, out_w, order):
    """Resample 2-D NumPy *data* to ``(out_h, out_w)``, tolerating NaN.

    Sample positions come from :func:`_block_centered_coords`, so they line
    up with the output coordinate metadata.  ``map_coordinates`` is always
    run with ``mode='nearest'``.

    * ``order == 0`` (nearest): NaN just propagates from the chosen pixel.
    * ``order >= 1`` with no NaN in *data*: plain interpolation.
    * ``order >= 1`` with NaN: NaNs are zero-filled and a validity-weight
      image is interpolated alongside the data.  The result is
      ``data / weight`` wherever ``weight > 0.5`` (majority of the kernel
      weight came from valid input) and NaN elsewhere.  This stops NaN from
      bleeding into neighbours and rejects pixels lit only by cubic-kernel
      sidelobes.

    For ``order >= 2`` the spline prefilter is run explicitly per array,
    via :func:`_prepad_and_filter_np`, and sampled with
    ``prefilter=False``.
    """
    rows = _block_centered_coords(data.shape[0], out_h)
    cols = _block_centered_coords(data.shape[1], out_w)
    grid_y, grid_x = np.meshgrid(rows, cols, indexing='ij')
    sample_at = np.array([grid_y.ravel(), grid_x.ravel()])
    out_shape = (out_h, out_w)

    if order == 0:
        flat = _scipy_map_coords(data, sample_at, order=0, mode='nearest')
        return flat.reshape(out_shape)

    # Bilinear's prefilter is a no-op, so only higher orders go explicit.
    explicit_prefilter = order >= 2

    def _sample(src):
        if not explicit_prefilter:
            return _scipy_map_coords(src, sample_at, order=order,
                                     mode='nearest')
        coeffs, pad = _prepad_and_filter_np(src, order)
        return _scipy_map_coords(coeffs, sample_at + pad, order=order,
                                 mode='nearest', prefilter=False)

    nan_mask = np.isnan(data)
    if not nan_mask.any():
        return _sample(data).reshape(out_shape)

    z_data = _sample(np.where(nan_mask, 0.0, data))
    z_wt = _sample((~nan_mask).astype(data.dtype))

    majority_valid = z_wt > 0.5
    normalised = z_data / np.maximum(z_wt, 1e-10)
    return np.where(majority_valid, normalised, np.nan).reshape(out_shape)


# -- NaN-aware interpolation (CuPy) -----------------------------------------

def _nan_aware_interp_cupy(data, out_h, out_w, order):
    """CuPy variant of :func:`_nan_aware_interp_np`.

    Resamples the 2-D CuPy array *data* to ``(out_h, out_w)`` with
    ``cupyx.scipy.ndimage.map_coordinates`` (always ``mode='nearest'``),
    sampling at the block-centered positions from
    ``_block_centered_coords``.

    Parameters
    ----------
    data : cupy.ndarray
        2-D input; rows along ``data.shape[0]``, columns along
        ``data.shape[1]``.  NaN marks missing cells.
    out_h, out_w : int
        Output height and width.
    order : int
        Spline order for ``map_coordinates`` (``INTERP_METHODS`` maps
        nearest/bilinear/cubic to 0/1/3).

    Returns
    -------
    cupy.ndarray
        Shape ``(out_h, out_w)``.  When ``order >= 1`` and *data* contains
        NaN, a pixel is NaN unless more than half of its interpolation
        weight came from valid input cells.
    """
    # cupyx is imported here rather than at module top -- presumably so the
    # module imports cleanly when CuPy is absent (``cupy`` falls back to
    # None at import time).
    from cupyx.scipy.ndimage import map_coordinates as _cupy_map_coords
    from cupyx.scipy.ndimage import spline_filter as _cupy_spline_filter

    # Block-centered sample positions are computed on the host (NumPy) and
    # copied to the device.
    iy = cupy.asarray(_block_centered_coords(data.shape[0], out_h))
    ix = cupy.asarray(_block_centered_coords(data.shape[1], out_w))
    yy, xx = cupy.meshgrid(iy, ix, indexing='ij')
    coords = cupy.array([yy.ravel(), xx.ravel()])

    # Nearest neighbour: NaN propagates naturally, no weighting needed.
    if order == 0:
        result = _cupy_map_coords(data, coords, order=0, mode='nearest')
        return result.reshape(out_h, out_w)

    # Explicit prefilter only for order >= 2 (bilinear's prefilter is a
    # no-op).
    use_explicit = order >= 2

    mask = cupy.isnan(data)
    # ``not mask.any()`` converts a device scalar to a host bool, which
    # forces a device synchronisation.
    if not mask.any():
        if use_explicit:
            src, npad = _prepad_and_filter_cupy(data, order, _cupy_spline_filter)
            result = _cupy_map_coords(src, coords + npad, order=order,
                                      mode='nearest', prefilter=False)
        else:
            result = _cupy_map_coords(data, coords, order=order,
                                      mode='nearest')
        return result.reshape(out_h, out_w)

    # Zero-fill / weight-mask trick: interpolate the NaN-zeroed data and a
    # 0/1 validity image with identical kernels, then normalise.
    filled = cupy.where(mask, 0.0, data)
    weights = (~mask).astype(data.dtype)

    if use_explicit:
        # Prefilter each array separately and shift samples into the padded
        # frame once for both lookups.
        filled, npad = _prepad_and_filter_cupy(filled, order, _cupy_spline_filter)
        weights, _ = _prepad_and_filter_cupy(weights, order, _cupy_spline_filter)
        sample_coords = coords + npad
        z_data = _cupy_map_coords(filled, sample_coords, order=order,
                                  mode='nearest', prefilter=False)
        z_wt = _cupy_map_coords(weights, sample_coords, order=order,
                                mode='nearest', prefilter=False)
    else:
        z_data = _cupy_map_coords(filled, coords, order=order,
                                  mode='nearest')
        z_wt = _cupy_map_coords(weights, coords, order=order,
                                mode='nearest')

    # Majority-weight gate: keep a pixel only when more than half of the
    # kernel weight came from valid input, which rejects pixels lit solely
    # by cubic sidelobes.  The 1e-10 floor only guards the division; gated
    # pixels have z_wt > 0.5 anyway.
    result = cupy.where(z_wt > 0.5,
                        z_data / cupy.maximum(z_wt, 1e-10),
                        cupy.nan)
    return result.reshape(out_h, out_w)


# -- Block-aggregation kernels (NumPy, numba) --------------------------------

@ngjit
def _agg_mean(data, out_h, out_w):
    """Block-aggregate 2-D *data* to ``(out_h, out_w)`` by NaN-skipping mean.

    Output pixel ``(row, col)`` averages input rows
    ``[int(row * h / out_h), max(lo + 1, int((row + 1) * h / out_h)))`` and
    the analogous column range, so every output pixel covers at least one
    input pixel.  Windows with no non-NaN values yield NaN.  The result is
    float64.
    """
    in_h, in_w = data.shape
    result = np.empty((out_h, out_w), dtype=np.float64)
    for row in range(out_h):
        r_lo = int(row * in_h / out_h)
        r_hi = max(r_lo + 1, int((row + 1) * in_h / out_h))
        for col in range(out_w):
            c_lo = int(col * in_w / out_w)
            c_hi = max(c_lo + 1, int((col + 1) * in_w / out_w))
            acc = 0.0
            n_valid = 0
            for r in range(r_lo, r_hi):
                for c in range(c_lo, c_hi):
                    val = data[r, c]
                    if np.isnan(val):
                        continue
                    acc += val
                    n_valid += 1
            if n_valid == 0:
                result[row, col] = np.nan
            else:
                result[row, col] = acc / n_valid
    return result


@ngjit
def _agg_min(data, out_h, out_w):
    """Block-aggregate 2-D *data* to ``(out_h, out_w)`` by NaN-skipping minimum.

    Output pixel ``(oy, ox)`` covers input rows
    ``[int(oy * h / out_h), max(y0 + 1, int((oy + 1) * h / out_h)))`` and the
    analogous column range, so every output pixel sees at least one input
    pixel.  Windows with no non-NaN values yield NaN; a window whose valid
    values are all ``+inf`` yields ``+inf``.  The result is float64.
    """
    h, w = data.shape
    out = np.empty((out_h, out_w), dtype=np.float64)
    for oy in range(out_h):
        y0 = int(oy * h / out_h)
        y1 = max(y0 + 1, int((oy + 1) * h / out_h))
        for ox in range(out_w):
            x0 = int(ox * w / out_w)
            x1 = max(x0 + 1, int((ox + 1) * w / out_w))
            best = np.inf
            found = False
            for y in range(y0, y1):
                for x in range(x0, x1):
                    v = data[y, x]
                    if np.isnan(v):
                        continue
                    # Any valid value populates the window.  Gating `found`
                    # on `v < best` (with best seeded at +inf) would leave
                    # an all-+inf window reported as NaN.
                    found = True
                    if v < best:
                        best = v
            out[oy, ox] = best if found else np.nan
    return out


@ngjit
def _agg_max(data, out_h, out_w):
    """Block-aggregate 2-D *data* to ``(out_h, out_w)`` by NaN-skipping maximum.

    Window bounds match :func:`_agg_min`.  Windows with no non-NaN values
    yield NaN; a window whose valid values are all ``-inf`` yields
    ``-inf``.  The result is float64.
    """
    h, w = data.shape
    out = np.empty((out_h, out_w), dtype=np.float64)
    for oy in range(out_h):
        y0 = int(oy * h / out_h)
        y1 = max(y0 + 1, int((oy + 1) * h / out_h))
        for ox in range(out_w):
            x0 = int(ox * w / out_w)
            x1 = max(x0 + 1, int((ox + 1) * w / out_w))
            best = -np.inf
            found = False
            for y in range(y0, y1):
                for x in range(x0, x1):
                    v = data[y, x]
                    if np.isnan(v):
                        continue
                    # Any valid value populates the window.  Gating `found`
                    # on `v > best` (with best seeded at -inf) would leave
                    # an all--inf window reported as NaN.
                    found = True
                    if v > best:
                        best = v
            out[oy, ox] = best if found else np.nan
    return out


@ngjit
def _agg_median(data, out_h, out_w):
    """Block-aggregate 2-D *data* to ``(out_h, out_w)`` by NaN-skipping median.

    Window bounds match the other eager aggregation kernels.  The valid
    values of each window are sorted.  An odd count takes the middle
    value; an even count averages the two middle values.  Empty windows
    yield NaN.  The result is float64.
    """
    in_h, in_w = data.shape
    result = np.empty((out_h, out_w), dtype=np.float64)
    for row in range(out_h):
        r_lo = int(row * in_h / out_h)
        r_hi = max(r_lo + 1, int((row + 1) * in_h / out_h))
        for col in range(out_w):
            c_lo = int(col * in_w / out_w)
            c_hi = max(c_lo + 1, int((col + 1) * in_w / out_w))
            valid = np.empty((r_hi - r_lo) * (c_hi - c_lo), dtype=np.float64)
            n_valid = 0
            for r in range(r_lo, r_hi):
                for c in range(c_lo, c_hi):
                    val = data[r, c]
                    if np.isnan(val):
                        continue
                    valid[n_valid] = val
                    n_valid += 1
            if n_valid == 0:
                result[row, col] = np.nan
                continue
            ordered = np.sort(valid[:n_valid])
            half = n_valid // 2
            if n_valid % 2 == 1:
                result[row, col] = ordered[half]
            else:
                result[row, col] = (ordered[half - 1] + ordered[half]) / 2.0
    return result


@ngjit
def _agg_mode(data, out_h, out_w):
    """Block-aggregate 2-D *data* to ``(out_h, out_w)`` by NaN-skipping mode.

    Window bounds match the other eager aggregation kernels.  The most
    frequent valid value wins.  Ties go to the smallest tied value, because
    runs are scanned in ascending order and a later run must be strictly
    longer to replace the current best.  Empty windows yield NaN.  The
    result is float64.
    """
    in_h, in_w = data.shape
    result = np.empty((out_h, out_w), dtype=np.float64)
    for row in range(out_h):
        r_lo = int(row * in_h / out_h)
        r_hi = max(r_lo + 1, int((row + 1) * in_h / out_h))
        for col in range(out_w):
            c_lo = int(col * in_w / out_w)
            c_hi = max(c_lo + 1, int((col + 1) * in_w / out_w))
            valid = np.empty((r_hi - r_lo) * (c_hi - c_lo), dtype=np.float64)
            n_valid = 0
            for r in range(r_lo, r_hi):
                for c in range(c_lo, c_hi):
                    val = data[r, c]
                    if np.isnan(val):
                        continue
                    valid[n_valid] = val
                    n_valid += 1
            if n_valid == 0:
                result[row, col] = np.nan
                continue
            ordered = np.sort(valid[:n_valid])
            mode_val = ordered[0]
            mode_len = 0
            start = 0
            while start < n_valid:
                stop = start + 1
                while stop < n_valid and ordered[stop] == ordered[start]:
                    stop += 1
                if stop - start > mode_len:
                    mode_len = stop - start
                    mode_val = ordered[start]
                start = stop
            result[row, col] = mode_val
    return result


# Aggregation method name -> eager (whole-array) numba kernel.
_AGG_FUNCS = dict(
    average=_agg_mean,
    min=_agg_min,
    max=_agg_max,
    median=_agg_median,
    mode=_agg_mode,
)


# -- Block-aggregation kernels for dask chunks -------------------------------
#
# These mirror the eager `_agg_mean / _agg_min / ...` family but compute
# per-pixel windows from the *global* input/output geometry and a chunk
# offset, rather than from the local block shape.  The whole chunk runs
# inside a single jitted call, instead of one numba dispatch per output
# pixel as the previous `func(sub, 1, 1)[0, 0]` loop did.
#
# Window bounds for output pixel `go` (a *global* output index):
#     gy0 = int(go * global_in_h / global_out_h) - in_y0
#     gy1 = max(gy0 + 1,
#               int((go + 1) * global_in_h / global_out_h) - in_y0)
# where `in_y0` is the global input index of the chunk's first row
# (negative if `_add_overlap` extended the chunk past the input edge).

@ngjit
def _agg_block_mean_nb(data, target_h, target_w,
                       go_y0, go_x0,
                       global_in_h, global_in_w,
                       global_out_h, global_out_w,
                       in_y0, in_x0):
    """NaN-skipping mean for one dask chunk, using global window geometry.

    *data* is the (possibly overlapped) input chunk, whose row 0 / column 0
    sit at global input index ``in_y0`` / ``in_x0``.  The chunk produces a
    ``(target_h, target_w)`` float64 tile starting at global output pixel
    ``(go_y0, go_x0)``.  Window bounds are derived from the global
    input/output sizes exactly as in :func:`_agg_mean`, then shifted into
    chunk-local indices.
    """
    result = np.empty((target_h, target_w), dtype=np.float64)
    for row in range(target_h):
        g_row = go_y0 + row
        r_lo = int(g_row * global_in_h / global_out_h) - in_y0
        r_hi = max(r_lo + 1,
                   int((g_row + 1) * global_in_h / global_out_h) - in_y0)
        for col in range(target_w):
            g_col = go_x0 + col
            c_lo = int(g_col * global_in_w / global_out_w) - in_x0
            c_hi = max(c_lo + 1,
                       int((g_col + 1) * global_in_w / global_out_w) - in_x0)
            acc = 0.0
            n_valid = 0
            for r in range(r_lo, r_hi):
                for c in range(c_lo, c_hi):
                    val = data[r, c]
                    if np.isnan(val):
                        continue
                    acc += val
                    n_valid += 1
            if n_valid == 0:
                result[row, col] = np.nan
            else:
                result[row, col] = acc / n_valid
    return result


@ngjit
def _agg_block_min_nb(data, target_h, target_w,
                      go_y0, go_x0,
                      global_in_h, global_in_w,
                      global_out_h, global_out_w,
                      in_y0, in_x0):
    """NaN-skipping minimum for one dask chunk, using global window geometry.

    *data* is the (possibly overlapped) input chunk, whose row 0 / column 0
    sit at global input index ``in_y0`` / ``in_x0``.  Produces a
    ``(target_h, target_w)`` float64 tile starting at global output pixel
    ``(go_y0, go_x0)``.  Windows with no valid values yield NaN; a window
    whose valid values are all ``+inf`` yields ``+inf``.
    """
    out = np.empty((target_h, target_w), dtype=np.float64)
    for lo_y in range(target_h):
        go_y = go_y0 + lo_y
        gy0 = int(go_y * global_in_h / global_out_h) - in_y0
        gy1 = int((go_y + 1) * global_in_h / global_out_h) - in_y0
        if gy1 < gy0 + 1:
            gy1 = gy0 + 1
        for lo_x in range(target_w):
            go_x = go_x0 + lo_x
            gx0 = int(go_x * global_in_w / global_out_w) - in_x0
            gx1 = int((go_x + 1) * global_in_w / global_out_w) - in_x0
            if gx1 < gx0 + 1:
                gx1 = gx0 + 1
            best = np.inf
            found = False
            for y in range(gy0, gy1):
                for x in range(gx0, gx1):
                    v = data[y, x]
                    if np.isnan(v):
                        continue
                    # Any valid value populates the window; otherwise an
                    # all-+inf window would never set `found`.
                    found = True
                    if v < best:
                        best = v
            out[lo_y, lo_x] = best if found else np.nan
    return out


@ngjit
def _agg_block_max_nb(data, target_h, target_w,
                      go_y0, go_x0,
                      global_in_h, global_in_w,
                      global_out_h, global_out_w,
                      in_y0, in_x0):
    """NaN-skipping maximum for one dask chunk, using global window geometry.

    Arguments are as for :func:`_agg_block_min_nb`.  Windows with no valid
    values yield NaN; a window whose valid values are all ``-inf`` yields
    ``-inf``.
    """
    out = np.empty((target_h, target_w), dtype=np.float64)
    for lo_y in range(target_h):
        go_y = go_y0 + lo_y
        gy0 = int(go_y * global_in_h / global_out_h) - in_y0
        gy1 = int((go_y + 1) * global_in_h / global_out_h) - in_y0
        if gy1 < gy0 + 1:
            gy1 = gy0 + 1
        for lo_x in range(target_w):
            go_x = go_x0 + lo_x
            gx0 = int(go_x * global_in_w / global_out_w) - in_x0
            gx1 = int((go_x + 1) * global_in_w / global_out_w) - in_x0
            if gx1 < gx0 + 1:
                gx1 = gx0 + 1
            best = -np.inf
            found = False
            for y in range(gy0, gy1):
                for x in range(gx0, gx1):
                    v = data[y, x]
                    if np.isnan(v):
                        continue
                    # Any valid value populates the window; otherwise an
                    # all--inf window would never set `found`.
                    found = True
                    if v > best:
                        best = v
            out[lo_y, lo_x] = best if found else np.nan
    return out


@ngjit
def _agg_block_median_nb(data, target_h, target_w,
                         go_y0, go_x0,
                         global_in_h, global_in_w,
                         global_out_h, global_out_w,
                         in_y0, in_x0):
    """NaN-skipping median for one dask chunk, using global window geometry.

    Arguments are as for the other ``_agg_block_*_nb`` kernels (*data* is
    the chunk, ``in_y0`` / ``in_x0`` its global origin, the output tile
    starts at global pixel ``(go_y0, go_x0)``).  Odd counts take the middle
    sorted value; even counts average the two middle values; empty windows
    yield NaN.
    """
    result = np.empty((target_h, target_w), dtype=np.float64)
    for row in range(target_h):
        g_row = go_y0 + row
        r_lo = int(g_row * global_in_h / global_out_h) - in_y0
        r_hi = max(r_lo + 1,
                   int((g_row + 1) * global_in_h / global_out_h) - in_y0)
        for col in range(target_w):
            g_col = go_x0 + col
            c_lo = int(g_col * global_in_w / global_out_w) - in_x0
            c_hi = max(c_lo + 1,
                       int((g_col + 1) * global_in_w / global_out_w) - in_x0)
            valid = np.empty((r_hi - r_lo) * (c_hi - c_lo), dtype=np.float64)
            n_valid = 0
            for r in range(r_lo, r_hi):
                for c in range(c_lo, c_hi):
                    val = data[r, c]
                    if np.isnan(val):
                        continue
                    valid[n_valid] = val
                    n_valid += 1
            if n_valid == 0:
                result[row, col] = np.nan
                continue
            ordered = np.sort(valid[:n_valid])
            half = n_valid // 2
            if n_valid % 2 == 1:
                result[row, col] = ordered[half]
            else:
                result[row, col] = (ordered[half - 1] + ordered[half]) / 2.0
    return result


@ngjit
def _agg_block_mode_nb(data, target_h, target_w,
                       go_y0, go_x0,
                       global_in_h, global_in_w,
                       global_out_h, global_out_w,
                       in_y0, in_x0):
    """NaN-skipping mode for one dask chunk, using global window geometry.

    Arguments are as for the other ``_agg_block_*_nb`` kernels.  The most
    frequent valid value wins.  Ties go to the smallest tied value, because
    runs are scanned in ascending order and a later run must be strictly
    longer.  Empty windows yield NaN.
    """
    result = np.empty((target_h, target_w), dtype=np.float64)
    for row in range(target_h):
        g_row = go_y0 + row
        r_lo = int(g_row * global_in_h / global_out_h) - in_y0
        r_hi = max(r_lo + 1,
                   int((g_row + 1) * global_in_h / global_out_h) - in_y0)
        for col in range(target_w):
            g_col = go_x0 + col
            c_lo = int(g_col * global_in_w / global_out_w) - in_x0
            c_hi = max(c_lo + 1,
                       int((g_col + 1) * global_in_w / global_out_w) - in_x0)
            valid = np.empty((r_hi - r_lo) * (c_hi - c_lo), dtype=np.float64)
            n_valid = 0
            for r in range(r_lo, r_hi):
                for c in range(c_lo, c_hi):
                    val = data[r, c]
                    if np.isnan(val):
                        continue
                    valid[n_valid] = val
                    n_valid += 1
            if n_valid == 0:
                result[row, col] = np.nan
                continue
            ordered = np.sort(valid[:n_valid])
            mode_val = ordered[0]
            mode_len = 0
            start = 0
            while start < n_valid:
                stop = start + 1
                while stop < n_valid and ordered[stop] == ordered[start]:
                    stop += 1
                if stop - start > mode_len:
                    mode_len = stop - start
                    mode_val = ordered[start]
                start = stop
            result[row, col] = mode_val
    return result


# Aggregation method name -> per-dask-chunk numba kernel.
_AGG_BLOCK_FUNCS = dict(
    average=_agg_block_mean_nb,
    min=_agg_block_min_nb,
    max=_agg_block_max_nb,
    median=_agg_block_median_nb,
    mode=_agg_block_mode_nb,
)


# -- Dask block helpers ------------------------------------------------------
#
# Interpolation uses map_coordinates with *global* coordinate mapping so
# that results are identical regardless of chunk layout.  Each block
# receives the cumulative chunk boundaries and computes which global
# output pixels it is responsible for, maps them back to global input
# coordinates, then converts to local (within-block) coordinates.


def _interp_block_np(block, global_in_h, global_in_w,
                     global_out_h, global_out_w,
                     cum_in_y, cum_in_x, cum_out_y, cum_out_x,
                     depth_y, depth_x, order, work_dtype, out_dtype,
                     block_info=None):
    """Interpolate one (possibly overlapped) numpy block.

    Sample positions are computed from the *global* input/output sizes, so
    each output pixel is sampled at the same global input coordinate as in
    the eager path, whatever the chunk layout.

    Parameters
    ----------
    block : numpy.ndarray
        2-D input chunk, cast to *work_dtype* before interpolation.
        NOTE(review): assumed to carry ``depth_y`` / ``depth_x`` overlap
        pixels before ``cum_in_y[yi]`` / ``cum_in_x[xi]`` (the local origin
        is shifted by them below) -- confirm how edge chunks are padded.
    global_in_h, global_in_w, global_out_h, global_out_w : int
        Full-array input and output sizes.
    cum_in_y, cum_in_x : sequence of int
        Cumulative input chunk boundaries; ``cum_in_y[yi]`` is the global
        row of this chunk's first non-overlap input row.
    cum_out_y, cum_out_x : sequence of int
        Cumulative output chunk boundaries; this chunk produces global
        output rows ``cum_out_y[yi]:cum_out_y[yi + 1]`` (same for columns).
    depth_y, depth_x : int
        Overlap subtracted from the chunk origin when converting to local
        coordinates.
    order : int
        Spline order for ``scipy.ndimage.map_coordinates``.
    work_dtype, out_dtype : numpy dtype
        Dtype used for the computation and dtype of the returned tile.
    block_info : dict
        Supplied by ``dask.array.map_blocks``; only
        ``block_info[0]['chunk-location']`` is read.

    Returns
    -------
    numpy.ndarray
        Tile of shape ``(cum_out_y[yi+1] - cum_out_y[yi],
        cum_out_x[xi+1] - cum_out_x[xi])`` in *out_dtype*.
    """
    yi, xi = block_info[0]['chunk-location']
    target_h = int(cum_out_y[yi + 1] - cum_out_y[yi])
    target_w = int(cum_out_x[xi + 1] - cum_out_x[xi])

    block = _maybe_astype(block, work_dtype)

    # Global output pixel indices for this chunk
    oy = np.arange(cum_out_y[yi], cum_out_y[yi + 1], dtype=np.float64)
    ox = np.arange(cum_out_x[xi], cum_out_x[xi + 1], dtype=np.float64)

    # Map to global input coordinates using block-centered formula
    # (same arithmetic, in the same order, as _block_centered_coords, so the
    # sample positions match the eager path exactly).
    iy = (oy + 0.5) * (global_in_h / global_out_h) - 0.5
    ix = (ox + 0.5) * (global_in_w / global_out_w) - 0.5

    # Convert to local block coordinates (overlap shifts the origin)
    iy_local = iy - (cum_in_y[yi] - depth_y)
    ix_local = ix - (cum_in_x[xi] - depth_x)

    yy, xx = np.meshgrid(iy_local, ix_local, indexing='ij')
    coords = np.array([yy.ravel(), xx.ravel()])

    # NaN-aware interpolation.  For order >= 2 we run the spline prefilter
    # explicitly per array (block / filled / weights), matching the eager
    # path; the residual IIR transient at interior chunk seams is what the
    # overlap depth is sized to absorb.
    use_explicit = order >= 2

    mask = np.isnan(block)
    # order == 0 needs no weighting: nearest lets NaN propagate naturally.
    if order == 0 or not mask.any():
        if use_explicit:
            src, npad = _prepad_and_filter_np(block, order)
            result = _scipy_map_coords(src, coords + npad, order=order,
                                       mode='nearest', prefilter=False)
        else:
            result = _scipy_map_coords(block, coords, order=order,
                                       mode='nearest')
    else:
        # Zero-fill / weight-mask trick: interpolate NaN-zeroed data and a
        # 0/1 validity image with identical kernels, then normalise.
        filled = np.where(mask, 0.0, block)
        weights = (~mask).astype(block.dtype)
        if use_explicit:
            filled, npad = _prepad_and_filter_np(filled, order)
            weights, _ = _prepad_and_filter_np(weights, order)
            sample_coords = coords + npad
            z_data = _scipy_map_coords(filled, sample_coords, order=order,
                                       mode='nearest', prefilter=False)
            z_wt = _scipy_map_coords(weights, sample_coords, order=order,
                                     mode='nearest', prefilter=False)
        else:
            z_data = _scipy_map_coords(filled, coords, order=order,
                                       mode='nearest')
            z_wt = _scipy_map_coords(weights, coords, order=order,
                                     mode='nearest')
        # Majority-weight gate (see _nan_aware_interp_np for rationale).
        result = np.where(z_wt > 0.5,
                          z_data / np.maximum(z_wt, 1e-10), np.nan)

    return _maybe_astype(result.reshape(target_h, target_w), out_dtype)


def _interp_block_cupy(block, global_in_h, global_in_w,
                       global_out_h, global_out_w,
                       cum_in_y, cum_in_x, cum_out_y, cum_out_x,
                       depth_y, depth_x, order, work_dtype, out_dtype,
                       block_info=None):
    """CuPy variant of :func:`_interp_block_np`.

    Interpolates one (possibly overlapped) CuPy block onto the output
    pixels that ``cum_out_y`` / ``cum_out_x`` assign to this chunk.

    Each output pixel is mapped to a global input coordinate with the
    block-centred formula ``(o + 0.5) * (in / out) - 0.5``. That
    coordinate is then shifted into the overlapped block's local frame,
    which starts ``depth`` pixels before the chunk.

    NaN handling:

    - Order 0, or a block without NaN, is sampled directly.
    - Otherwise the NaN-filled data and a validity-weight array are
      sampled separately, and cells are kept only where the interpolated
      weight exceeds 0.5.
    - For ``order >= 2`` the spline prefilter is run explicitly on every
      sampled array, so the boundary transient matches the eager path.
    """
    from cupyx.scipy.ndimage import map_coordinates as _cupy_map_coords
    from cupyx.scipy.ndimage import spline_filter as _cupy_spline_filter

    row_idx, col_idx = block_info[0]['chunk-location']
    out_y0, out_y1 = int(cum_out_y[row_idx]), int(cum_out_y[row_idx + 1])
    out_x0, out_x1 = int(cum_out_x[col_idx]), int(cum_out_x[col_idx + 1])

    if block.dtype != cupy.dtype(work_dtype):
        block = block.astype(work_dtype)

    # Global output index -> global input coordinate -> local block
    # coordinate (the overlap shifts the block origin back by depth).
    local_y = ((cupy.arange(out_y0, out_y1, dtype=cupy.float64) + 0.5)
               * (global_in_h / global_out_h) - 0.5
               - float(cum_in_y[row_idx] - depth_y))
    local_x = ((cupy.arange(out_x0, out_x1, dtype=cupy.float64) + 0.5)
               * (global_in_w / global_out_w) - 0.5
               - float(cum_in_x[col_idx] - depth_x))

    grid_y, grid_x = cupy.meshgrid(local_y, local_x, indexing='ij')
    coords = cupy.array([grid_y.ravel(), grid_x.ravel()])

    prefilter_explicitly = order >= 2

    def _sample(arr):
        # Sample *arr* at ``coords``. Higher orders get an explicit
        # pre-pad + spline prefilter and skip map_coordinates' own one.
        if prefilter_explicitly:
            padded, npad = _prepad_and_filter_cupy(arr, order,
                                                   _cupy_spline_filter)
            return _cupy_map_coords(padded, coords + npad, order=order,
                                    mode='nearest', prefilter=False)
        return _cupy_map_coords(arr, coords, order=order, mode='nearest')

    nan_mask = cupy.isnan(block)
    if order == 0 or not nan_mask.any():
        result = _sample(block)
    else:
        z_data = _sample(cupy.where(nan_mask, 0.0, block))
        z_wt = _sample((~nan_mask).astype(block.dtype))
        # Majority-weight gate (see _nan_aware_interp_np for rationale).
        result = cupy.where(z_wt > 0.5,
                            z_data / cupy.maximum(z_wt, 1e-10), cupy.nan)

    result = result.reshape(out_y1 - out_y0, out_x1 - out_x0)
    if result.dtype != cupy.dtype(out_dtype):
        result = result.astype(out_dtype)
    return result


def _agg_block_np(block, method, global_in_h, global_in_w,
                  global_out_h, global_out_w,
                  cum_in_y, cum_in_x, cum_out_y, cum_out_x,
                  depth_y, depth_x, out_dtype, block_info=None):
    """Block-aggregate one (possibly overlapped) numpy chunk.

    The whole chunk is handled by a single call into the jitted
    ``_AGG_BLOCK_FUNCS[method]`` kernel, which writes every output pixel
    of this chunk in one numba dispatch. This replaced an older
    per-output-pixel 1x1 dispatch that scaled badly for large rasters.

    The kernel is told where this chunk sits in the global grid:

    - the output origin ``go_y0``/``go_x0``;
    - the global input and output sizes;
    - the global input index of the overlapped block's first row and
      column, which sits ``depth`` pixels before the chunk itself.
    """
    chunk_row, chunk_col = block_info[0]['chunk-location']
    out_row0, out_row1 = int(cum_out_y[chunk_row]), int(cum_out_y[chunk_row + 1])
    out_col0, out_col1 = int(cum_out_x[chunk_col]), int(cum_out_x[chunk_col + 1])

    # The jitted kernels use float64 working buffers; match that so the
    # numba dispatch signature lines up.
    block = _maybe_astype(block, np.float64)

    aggregate = _AGG_BLOCK_FUNCS[method]
    out = aggregate(
        block,
        out_row1 - out_row0, out_col1 - out_col0,
        out_row0, out_col0,
        int(global_in_h), int(global_in_w),
        int(global_out_h), int(global_out_w),
        int(cum_in_y[chunk_row]) - depth_y,
        int(cum_in_x[chunk_col]) - depth_x,
    )
    return _maybe_astype(out, out_dtype)


def _agg_block_cupy(block, method, global_in_h, global_in_w,
                    global_out_h, global_out_w,
                    cum_in_y, cum_in_x, cum_out_y, cum_out_x,
                    depth_y, depth_x, out_dtype, block_info=None):
    """Block-aggregate one CuPy chunk by round-tripping through the host.

    The block is copied to a numpy array and aggregated with
    :func:`_agg_block_np`, and the result is copied back to the device.
    """
    host_result = _agg_block_np(
        cupy.asnumpy(block), method,
        global_in_h, global_in_w, global_out_h, global_out_w,
        cum_in_y, cum_in_x, cum_out_y, cum_out_x,
        depth_y, depth_x, out_dtype,
        block_info=block_info,
    )
    return cupy.asarray(host_result)


# -- Per-backend runners -----------------------------------------------------

def _run_numpy(data, scale_y, scale_x, method):
    """Eagerly resample a 2-D numpy array.

    The input is cast to the working dtype and resampled to the shape
    given by ``_output_shape``. Interpolation methods use
    :func:`_nan_aware_interp_np` with the spline order from
    ``INTERP_METHODS``; aggregation methods use ``_AGG_FUNCS[method]``.
    The result is cast to the output dtype chosen for the input dtype.
    """
    out_dt = _output_dtype(data.dtype)
    data = _maybe_astype(data, _working_dtype(data.dtype))
    out_h, out_w = _output_shape(*data.shape, scale_y, scale_x)

    if method in INTERP_METHODS:
        resampled = _nan_aware_interp_np(data, out_h, out_w,
                                         INTERP_METHODS[method])
    else:
        resampled = _AGG_FUNCS[method](data, out_h, out_w)
    return _maybe_astype(resampled, out_dt)


def _run_cupy(data, scale_y, scale_x, method):
    """Eagerly resample a 2-D CuPy array.

    Dispatch by method:

    - Interpolation runs on the device via :func:`_nan_aware_interp_cupy`.
    - ``average``/``min``/``max`` with an exact integer reduction factor
      on both axes run on the device as a reshape followed by a
      NaN-aware reduction.
    - Every other aggregate is copied to the host, computed with
      ``_AGG_FUNCS``, and copied back.

    The result is returned in the output dtype chosen for the input dtype.
    """
    work_dt = _working_dtype(data.dtype)
    out_dt = _output_dtype(data.dtype)
    if data.dtype != cupy.dtype(work_dt):
        data = data.astype(work_dt)
    out_h, out_w = _output_shape(*data.shape, scale_y, scale_x)

    def _to_out_dtype(arr):
        # Avoid a copy when the dtype already matches.
        if arr.dtype == cupy.dtype(out_dt):
            return arr
        return arr.astype(out_dt)

    if method in INTERP_METHODS:
        interp = _nan_aware_interp_cupy(data, out_h, out_w,
                                        INTERP_METHODS[method])
        return _to_out_dtype(interp)

    ratio_y = data.shape[0] / out_h
    ratio_x = data.shape[1] / out_w
    exact_factors = ratio_y == int(ratio_y) and ratio_x == int(ratio_x)
    device_reducers = {'average': cupy.nanmean,
                       'min': cupy.nanmin,
                       'max': cupy.nanmax}

    if exact_factors and method in device_reducers:
        fy, fx = int(ratio_y), int(ratio_x)
        windows = data[:out_h * fy, :out_w * fx].reshape(out_h, fy, out_w, fx)
        return _to_out_dtype(device_reducers[method](windows, axis=(1, 3)))

    # Non-integer factors, median and mode: fall back to the CPU kernels.
    host_result = _AGG_FUNCS[method](cupy.asnumpy(data), out_h, out_w)
    return cupy.asarray(_maybe_astype(host_result, out_dt))


def _min_chunksize_for_scale(scale):
    """Minimum input chunk size so that no output chunk is zero after rounding."""
    if scale >= 1.0:
        return 1
    # c > 1/s guarantees round((k+1)*c*s) - round(k*c*s) >= 1 for all k.
    return int(1.0 / scale) + 1


def _downsample_radius(scale):
    """Extra interp overlap (input pixels) needed for a downsample on one axis.

    Block-centered mapping sends output pixel ``o`` to input coordinate
    ``(o + 0.5) * (in/out) - 0.5``. When ``scale < 1`` (downsampling), the
    source coordinate of an output pixel near a chunk seam can sit up to
    about ``(in/out)/2`` input pixels beyond the chunk ``_output_chunks``
    assigned it to. Returning ``ceil((1/scale)/2) + 1`` covers that
    displacement (the ``+1`` absorbs the half-pixel coordinate offset and
    the cumulative-rounding mismatch between ``_output_chunks`` and the
    per-pixel mapping). Upsampling needs none, so return 0 for ``scale >= 1``.
    """
    import math
    if scale >= 1.0:
        return 0
    return int(math.ceil((1.0 / scale) / 2.0)) + 1


def _ensure_min_chunksize(data, min_size):
    """Rechunk *data* so every chunk is at least *min_size* pixels wide."""
    import math
    new = {}
    for ax in range(data.ndim):
        if any(c < min_size for c in data.chunks[ax]):
            total = sum(data.chunks[ax])
            # Find chunk size where ALL chunks (including last) >= min_size
            n = max(1, total // min_size)
            cs = math.ceil(total / n)
            while n > 1:
                remainder = total - cs * (total // cs)
                if remainder == 0 or remainder >= min_size:
                    break
                n -= 1
                cs = math.ceil(total / n)
            new[ax] = cs
    return data.rechunk(new) if new else data


def _run_dask_numpy(data, scale_y, scale_x, method):
    """Resample a dask array of numpy chunks, one overlapped block at a time.

    Both method families follow the same plan:

    1. Cast to the working dtype.
    2. Pick a per-axis overlap depth.
       - Interpolation: the kernel stencil (``_INTERP_DEPTH``) plus, when
         downsampling, the seam displacement (``_downsample_radius``,
         issue #2610). Without that extra depth, seam rows clamp to the
         block edge. The depth is clamped to ``axis - 1``, because
         ``dask.overlap`` rejects ``depth > sum(chunks)`` and the eager
         kernels accept arbitrarily small inputs.
       - Aggregation: ``ceil(in / out)``, because aggregate windows can
         cross chunk boundaries.
    3. Rechunk so each chunk holds at least ``2 * depth + 1`` pixels and
       is large enough that no output chunk rounds to zero.
    4. Overlap the chunks and map the per-block kernel over them.
       Interpolation pads with ``'nearest'`` and skips the overlap when
       both depths clamp to 0. Aggregation pads with NaN, which its
       kernels ignore, so padding never contaminates the global edges.

    Returns a lazy dask array in ``_output_dtype(data.dtype)``.
    """
    import math

    from dask.array.overlap import overlap as _add_overlap

    work_dt = _working_dtype(data.dtype)
    out_dt = _output_dtype(data.dtype)
    if data.dtype != np.dtype(work_dt):
        data = data.astype(work_dt)

    # Rechunking never changes the axis totals, so the global geometry
    # can be fixed before choosing chunk sizes.
    in_h = int(sum(data.chunks[0]))
    in_w = int(sum(data.chunks[1]))
    out_h, out_w = _output_shape(in_h, in_w, scale_y, scale_x)

    if method in INTERP_METHODS:
        stencil = _INTERP_DEPTH[method]
        depth_y = min(stencil + _downsample_radius(scale_y), max(0, in_h - 1))
        depth_x = min(stencil + _downsample_radius(scale_x), max(0, in_w - 1))
        pad_mode = 'nearest'
        do_overlap = depth_y > 0 or depth_x > 0
        block_kernel = partial(_interp_block_np,
                               order=INTERP_METHODS[method],
                               work_dtype=work_dt)
    else:
        depth_y = math.ceil(in_h / out_h)
        depth_x = math.ceil(in_w / out_w)
        pad_mode = np.nan
        do_overlap = True
        block_kernel = partial(_agg_block_np, method=method)

    min_size = max(2 * max(depth_y, depth_x) + 1,
                   _min_chunksize_for_scale(scale_y),
                   _min_chunksize_for_scale(scale_x))
    data = _ensure_min_chunksize(data, min_size)

    out_y = _output_chunks(data.chunks[0], scale_y)
    out_x = _output_chunks(data.chunks[1], scale_x)

    def _edges(sizes):
        # Chunk boundaries as a running total: [0, s0, s0 + s1, ...].
        return np.cumsum([0] + list(sizes))

    src = data
    if do_overlap:
        src = _add_overlap(data, depth={0: depth_y, 1: depth_x},
                           boundary=pad_mode)

    fn = partial(block_kernel,
                 global_in_h=in_h, global_in_w=in_w,
                 global_out_h=out_h, global_out_w=out_w,
                 cum_in_y=_edges(data.chunks[0]),
                 cum_in_x=_edges(data.chunks[1]),
                 cum_out_y=_edges(out_y),
                 cum_out_x=_edges(out_x),
                 depth_y=depth_y, depth_x=depth_x,
                 out_dtype=out_dt)
    return da.map_blocks(fn, src, chunks=(out_y, out_x),
                         dtype=out_dt, meta=np.array((), dtype=out_dt))


def _run_dask_cupy(data, scale_y, scale_x, method):
    """Resample a dask array of CuPy chunks, one overlapped block at a time.

    Follows the same plan as :func:`_run_dask_numpy`: cast, choose a
    per-axis overlap depth, rechunk, overlap, then map the per-block
    kernel. The differences are the CuPy kernels
    (:func:`_interp_block_cupy` / :func:`_agg_block_cupy`) and a CuPy
    ``meta``.

    Interpolation depth is the kernel stencil plus the downsample radius
    (#2610), clamped to ``axis - 1``, with ``'nearest'`` padding.
    Aggregation depth is ``ceil(in / out)`` with NaN padding, so the
    global edges are ignored by the NaN-skipping kernels.

    Returns a lazy dask array in ``_output_dtype(data.dtype)``.
    """
    import math

    from dask.array.overlap import overlap as _add_overlap

    work_dt = _working_dtype(data.dtype)
    out_dt = _output_dtype(data.dtype)
    if data.dtype != cupy.dtype(work_dt):
        data = data.astype(work_dt)

    # Axis totals are invariant under rechunking; fix the geometry first.
    in_h = int(sum(data.chunks[0]))
    in_w = int(sum(data.chunks[1]))
    out_h, out_w = _output_shape(in_h, in_w, scale_y, scale_x)

    if method in INTERP_METHODS:
        stencil = _INTERP_DEPTH[method]
        depth_y = min(stencil + _downsample_radius(scale_y), max(0, in_h - 1))
        depth_x = min(stencil + _downsample_radius(scale_x), max(0, in_w - 1))
        pad_mode = 'nearest'
        do_overlap = depth_y > 0 or depth_x > 0
        block_kernel = partial(_interp_block_cupy,
                               order=INTERP_METHODS[method],
                               work_dtype=work_dt)
    else:
        depth_y = math.ceil(in_h / out_h)
        depth_x = math.ceil(in_w / out_w)
        pad_mode = cupy.nan
        do_overlap = True
        block_kernel = partial(_agg_block_cupy, method=method)

    min_size = max(2 * max(depth_y, depth_x) + 1,
                   _min_chunksize_for_scale(scale_y),
                   _min_chunksize_for_scale(scale_x))
    data = _ensure_min_chunksize(data, min_size)

    out_y = _output_chunks(data.chunks[0], scale_y)
    out_x = _output_chunks(data.chunks[1], scale_x)

    def _edges(sizes):
        # Chunk boundaries as a running total: [0, s0, s0 + s1, ...].
        return np.cumsum([0] + list(sizes))

    src = data
    if do_overlap:
        src = _add_overlap(data, depth={0: depth_y, 1: depth_x},
                           boundary=pad_mode)

    fn = partial(block_kernel,
                 global_in_h=in_h, global_in_w=in_w,
                 global_out_h=out_h, global_out_w=out_w,
                 cum_in_y=_edges(data.chunks[0]),
                 cum_in_x=_edges(data.chunks[1]),
                 cum_out_y=_edges(out_y),
                 cum_out_x=_edges(out_x),
                 depth_y=depth_y, depth_x=depth_x,
                 out_dtype=out_dt)
    return da.map_blocks(fn, src, chunks=(out_y, out_x),
                         dtype=out_dt, meta=cupy.array((), dtype=out_dt))


# -- Public API --------------------------------------------------------------

def _resolve_nodata(agg, nodata):
    """Resolve the input-side nodata sentinel.

    Explicit *nodata* wins. Otherwise fall back to ``_FillValue`` then
    ``nodata`` in ``agg.attrs``. Returns ``None`` when no sentinel was
    found (the caller skips the masking step).

    For floating-point inputs the sentinel is returned as a Python
    ``float`` so the caller can branch on ``np.isnan`` rather than
    ``==`` (which never matches NaN). For integer / bool inputs the
    sentinel is cast to the input dtype so the comparison happens in
    integer space -- routing it through ``float`` would lose precision
    for int64 values above 2**53.
    """
    if nodata is None:
        for key in ('_FillValue', 'nodata'):
            v = agg.attrs.get(key)
            if v is not None:
                nodata = v
                break
    if nodata is None:
        return None
    if np.issubdtype(agg.dtype, np.floating):
        nd = float(nodata)
        if np.isinf(nd):
            raise ValueError(f"nodata must be finite or NaN, got {nodata!r}")
        return nd
    # Integer / bool input: keep the sentinel in the input's native
    # dtype so the equality test in _apply_nodata_mask compares
    # integer-to-integer. A NaN sentinel can never match an integer
    # value, so signal a no-op mask by returning NaN unchanged.
    if isinstance(nodata, float) and np.isnan(nodata):
        return float('nan')
    # Reject fractional float sentinels for integer inputs -- silently
    # truncating to int would mask cells the caller never asked to mask.
    if isinstance(nodata, float) and not nodata.is_integer():
        raise ValueError(
            f"nodata={nodata!r} is not representable in integer dtype "
            f"{agg.dtype}; pass an integer sentinel instead."
        )
    # Integer inputs: an out-of-range sentinel wraps on cast (e.g. 999
    # becomes 231 for uint8), masking the wrong cells. Require the value
    # to round-trip exactly into agg.dtype before trusting the cast.
    info = np.iinfo(agg.dtype)
    nd_int = int(nodata)
    # A sentinel beyond the dtype range either wraps (numpy fixed-width
    # cast) or overflows the C-long conversion for very large Python
    # ints. Range-check up front so both surface the same ValueError
    # instead of a raw OverflowError.
    if nd_int < info.min or nd_int > info.max:
        raise ValueError(
            f"nodata={nodata!r} is out of range for integer dtype "
            f"{agg.dtype} (valid range [{info.min}, {info.max}])."
        )
    return np.asarray(nd_int).astype(agg.dtype).item()


def _apply_nodata_mask(agg, nodata):
    """Return a float copy of *agg* with sentinel pixels replaced by NaN.

    Works for numpy, cupy, dask+numpy, and dask+cupy backings via
    xarray's ``.where`` (which dispatches per backend).
    """
    if nodata is None:
        return agg
    is_float_input = np.issubdtype(agg.dtype, np.floating)
    # For floating-point input a NaN sentinel needs no replacement
    # (NaN is already the output convention). For integer input a NaN
    # sentinel can never match any cell, so the mask is a no-op; still
    # promote to float so downstream NaN handling has somewhere to
    # write its sentinels.
    if is_float_input and isinstance(nodata, float) and np.isnan(nodata):
        return agg
    # Compare in the input dtype FIRST so integer comparisons keep
    # full precision (float64 cannot represent int64 values above
    # 2**53 without rounding). Then promote to float so NaN can be
    # stored in the masked output.
    mask = agg != nodata
    if not is_float_input:
        agg = agg.astype(np.float32)
    return agg.where(mask)


def _refresh_nodata_attrs(src_attrs, dst_attrs):
    """Refresh nodata sentinels in *dst_attrs* to NaN.

    Resample replaces sentinel pixels with NaN regardless of input
    dtype. If the input declared a sentinel via ``_FillValue``,
    ``nodatavals``, or the rasterio-style ``nodata`` attr, refresh each
    one to NaN so the metadata matches the actual data. Keys absent on
    the input stay absent. ``_resolve_nodata`` reads ``nodata`` as a
    fallback, so a stale finite value there would silently mismatch the
    masked data on any downstream consumer that trusts
    ``attrs['nodata']``.
    """
    if '_FillValue' in src_attrs:
        dst_attrs['_FillValue'] = float('nan')
    if 'nodatavals' in src_attrs:
        old = src_attrs['nodatavals']
        dst_attrs['nodatavals'] = tuple(float('nan') for _ in old)
    if 'nodata' in src_attrs:
        dst_attrs['nodata'] = float('nan')


# NOTE(review): the `[docs]` line below looks like Sphinx viewcode link
# residue from the scraped "Source code for xrspatial.resample" listing,
# not real Python; it should be deleted when restoring the module.
[docs]
@supports_dataset
def resample(
    agg: xr.DataArray,
    scale_factor: float | tuple[float, float] | None = None,
    target_resolution: float | tuple[float, float] | None = None,
    method: str = 'nearest',
    nodata: float | None = None,
    name: str = 'resample',
) -> xr.DataArray:
    """Change raster resolution without changing its CRS.

    Exactly one of *scale_factor* or *target_resolution* must be given.

    Parameters
    ----------
    agg : xarray.DataArray or xarray.Dataset
        Input raster. 2-D ``(y, x)`` or 3-D ``(band, y, x)``. For 3-D
        inputs each band is resampled independently and the leading
        non-spatial coordinate is preserved. If a Dataset is passed, the
        operation is applied to each data variable independently (via
        the ``@supports_dataset`` decorator).
    scale_factor : float or (float, float), optional
        Multiplicative factor applied to the number of pixels. ``0.5``
        halves the pixel count (doubles the cell size); ``2.0`` doubles
        the pixel count (halves the cell size). A two-element tuple sets
        ``(scale_y, scale_x)`` independently.
    target_resolution : float or (float, float), optional
        Desired cell size in the same units as the raster coordinates.
        A scalar sets both axes to the same resolution; a 2-tuple sets
        ``(res_y, res_x)`` independently. Converted to scale factors as
        ``abs(input_res) / target_resolution`` per axis.
    method : str, default ``'nearest'``
        Resampling algorithm. Interpolation methods (``'nearest'``,
        ``'bilinear'``, ``'cubic'``) work for both upsampling and
        downsampling. Aggregation methods (``'average'``, ``'min'``,
        ``'max'``, ``'median'``, ``'mode'``) only support downsampling
        (scale_factor <= 1).
    nodata : float, optional
        Sentinel value in the input that should be treated as missing.
        Input pixels equal to *nodata* are replaced with NaN before
        resampling. When ``None``, falls back to
        ``agg.attrs['_FillValue']`` then ``agg.attrs['nodata']``. The
        output uses NaN as the sentinel regardless of the input
        convention.
    name : str, default ``'resample'``
        Name for the output DataArray.

    Returns
    -------
    xarray.DataArray
        Resampled raster with updated coordinates and ``res`` attribute.
        Output dtype matches the input float dtype (float32 or float64);
        integer inputs return float32 since NaN-sentinel resampling
        requires a float type. When a sentinel was applied, ``_FillValue``
        is set to NaN and any ``nodata`` / ``nodatavals`` attrs are
        refreshed to NaN. A ``transform`` attr, if present, is rebuilt
        for the new grid. Zero-dimensional non-dimension coordinates
        (e.g. ``spatial_ref``) and the attrs of the spatial dimension
        coordinates are carried over.

    Raises
    ------
    ValueError
        If ``agg`` has a zero-length spatial dimension; if neither or
        both of ``scale_factor`` and ``target_resolution`` are given; if
        either is a sequence whose length is not 2; if any component is
        zero, negative, NaN, or infinite; if ``method`` is not in
        :data:`ALL_METHODS`; if ``target_resolution`` is used on a raster
        with fewer than 2 pixels along either spatial axis; if an
        aggregation ``method`` is combined with a scale factor above 1;
        if the spatial coordinates of ``agg`` are not strictly monotonic
        and evenly spaced (``resample`` only supports regular monotonic
        rasters); or if ``nodata`` does not round-trip exactly into an
        integer ``agg.dtype`` (a fractional or out-of-range sentinel that
        would wrap on the cast).
    """
    # NOTE(review): the indentation of this function was reconstructed
    # from a flattened listing; every code token is unchanged.
    _validate_raster(agg, func_name='resample', name='agg', ndim=(2, 3))
    _validate_monotonic_regular_coords(agg)

    # Reject empty rasters up front. A zero-length spatial axis would
    # otherwise reach the output-coordinate rebuild and surface as an
    # opaque IndexError (vals[0] on an empty coord array) rather than a
    # clear, parameter-named error.
    if agg.shape[-2] == 0 or agg.shape[-1] == 0:
        raise ValueError(
            f"resample(): `agg` must have non-empty spatial dimensions, "
            f"got shape {tuple(agg.shape)}"
        )

    if method not in ALL_METHODS:
        raise ValueError(
            f"method must be one of {sorted(ALL_METHODS)}, got {method!r}"
        )

    # -- resolve scale factors -----------------------------------------------
    # `(a is None) == (b is None)` is True when both or neither are given.
    if (scale_factor is None) == (target_resolution is None):
        raise ValueError(
            "Exactly one of scale_factor or target_resolution must be given"
        )
    # Validate shape, finiteness, and positivity of whichever input was
    # supplied. Fails fast with a parameter-named message before any
    # geometry math runs, so overlong/short tuples, zero, and NaN/inf
    # do not surface later as IndexError / ZeroDivisionError / opaque
    # numpy conversion errors.
    if target_resolution is not None:
        _validate_resample_scalar_or_pair(
            target_resolution, 'target_resolution'
        )
    else:
        _validate_resample_scalar_or_pair(scale_factor, 'scale_factor')

    if target_resolution is not None:
        # The input resolution is derived from coordinate spacing, which
        # needs at least two samples per axis.
        if agg.shape[-2] < 2 or agg.shape[-1] < 2:
            raise ValueError(
                "target_resolution requires at least 2 pixels per dimension"
            )
        # calc_res returns (x, y) order; the scale factors below are
        # (y, x) to match the array layout.
        res_x, res_y = calc_res(agg)
        if isinstance(target_resolution, (tuple, list)):
            scale_y = abs(res_y) / target_resolution[0]
            scale_x = abs(res_x) / target_resolution[1]
        else:
            scale_y = abs(res_y) / target_resolution
            scale_x = abs(res_x) / target_resolution
    elif isinstance(scale_factor, (tuple, list)):
        scale_y, scale_x = float(scale_factor[0]), float(scale_factor[1])
    else:
        scale_y = scale_x = float(scale_factor)

    # Defence-in-depth: the public inputs were already validated above
    # by ``_validate_resample_scalar_or_pair``, so on the scale_factor
    # path this branch is unreachable. It still fires on the
    # target_resolution path if ``calc_res(agg)`` returns zero from a
    # degenerate coord array.
    if scale_y <= 0 or scale_x <= 0:
        raise ValueError(
            f"Scale factors must be positive, got ({scale_y}, {scale_x})"
        )

    # Block aggregation only makes sense when output cells cover one or
    # more input cells.
    if method in AGGREGATE_METHODS and (scale_y > 1.0 or scale_x > 1.0):
        raise ValueError(
            f"Aggregate method {method!r} only supports downsampling "
            f"(scale_factor <= 1.0)"
        )

    # -- nodata: replace sentinels with NaN before resampling ----------------
    nd_resolved = _resolve_nodata(agg, nodata)
    has_nodata = nd_resolved is not None
    if has_nodata:
        agg = _apply_nodata_mask(agg, nd_resolved)

    # -- fast path: identity -------------------------------------------------
    # Exact float equality: only skips work when the factors are exactly 1.
    if scale_y == 1.0 and scale_x == 1.0:
        out = agg.copy()
        out.name = name
        # When nodata was applied, advertise NaN as the new sentinel.
        if has_nodata:
            # Always advertise NaN via `_FillValue` -- this also covers the
            # explicit `nodata=` case where the input carried no nodata
            # attrs. Then refresh `nodata` / `nodatavals` for inputs that
            # did declare them, so masked-to-NaN output never advertises a
            # stale finite sentinel (the non-identity path does the same).
            out.attrs['_FillValue'] = float('nan')
            _refresh_nodata_attrs(agg.attrs, out.attrs)
        return out

    # -- 3D: dispatch per band ----------------------------------------------
    if agg.ndim == 3:
        leading_dim = agg.dims[0]
        bands = []
        for i in range(agg.sizes[leading_dim]):
            band_2d = agg.isel({leading_dim: i})
            band_out = resample(
                band_2d,
                scale_factor=scale_factor,
                target_resolution=target_resolution,
                method=method,
                # Pass NaN so the recursive call short-circuits masking
                # (we already applied the mask on the 3D input above) and
                # ignores the original attrs sentinel.
                nodata=float('nan'),
                name=name,
            )
            bands.append(band_out)
        # Stack along the leading dim. concat preserves the per-band
        # coordinate when each input has it.
        result = xr.concat(bands, dim=leading_dim)
        # concat may reorder dims; transpose to the original layout.
        result = result.transpose(*agg.dims)
        result.name = name
        # Carry across input attrs (concat picks the first; merge with input).
        new_attrs = dict(agg.attrs)
        new_attrs.update(bands[0].attrs)  # res from per-band resample
        if has_nodata:
            new_attrs['_FillValue'] = float('nan')
            _refresh_nodata_attrs(agg.attrs, new_attrs)
        result.attrs = new_attrs
        # Preserve the leading-dim coordinate if it was on the input.
        if leading_dim in agg.coords:
            result = result.assign_coords({leading_dim: agg.coords[leading_dim]})
        return result

    # -- memory guard for eager backends ------------------------------------
    # Dask paths build per-chunk allocations lazily (chunk size already
    # bounds peak memory). The eager numpy and cupy paths allocate the
    # full (out_h, out_w) buffer up front and need an explicit guard.
    in_h, in_w = agg.shape[-2:]
    out_h, out_w = _output_shape(in_h, in_w, scale_y, scale_x)
    is_dask = da is not None and isinstance(agg.data, da.Array)
    is_cupy = cupy is not None and isinstance(agg.data, cupy.ndarray)
    if not is_dask:
        if is_cupy:
            _check_resample_gpu_memory(out_h, out_w)
        else:
            _check_resample_memory(out_h, out_w)

    # -- dispatch to backend -------------------------------------------------
    mapper = ArrayTypeFunctionMapping(
        numpy_func=_run_numpy,
        cupy_func=_run_cupy,
        dask_func=_run_dask_numpy,
        dask_cupy_func=_run_dask_cupy,
    )
    result_data = mapper(agg)(agg.data, scale_y, scale_x, method)

    # -- build output coordinates -------------------------------------------
    ydim, xdim = agg.dims[-2], agg.dims[-1]
    y_vals = np.asarray(agg[ydim].values, dtype=np.float64)
    x_vals = np.asarray(agg[xdim].values, dtype=np.float64)

    def _new_coords(vals, n_out):
        # Rebuild n_out cell-centre coordinates spanning the same outer
        # edges as the input. Edges sit half a step beyond the first and
        # last centres; `px` is signed (negative when `vals` descends).
        if len(vals) > 1:
            half_first = (vals[1] - vals[0]) / 2
            half_last = (vals[-1] - vals[-2]) / 2
        else:
            # Single-pixel axis: no spacing to infer, so a unit cell size
            # is assumed.
            half_first = half_last = 0.5
        edge_start = vals[0] - half_first
        edge_end = vals[-1] + half_last
        px = (edge_end - edge_start) / n_out
        coords = np.linspace(edge_start + px / 2, edge_end - px / 2, n_out)
        return coords, px, edge_start, edge_end

    new_y, py, y_edge_start, y_edge_end = _new_coords(y_vals, out_h)
    new_x, px, x_edge_start, x_edge_end = _new_coords(x_vals, out_w)

    new_attrs = dict(agg.attrs)
    new_attrs['res'] = (abs(px), abs(py))
    if has_nodata:
        new_attrs['_FillValue'] = float('nan')
    # Refresh `transform` if the input had one. The rasterio 6-tuple
    # `(a, b, c, d, e, f)` maps `(col, row) -> (x, y)` for the first
    # array pixel at `(col=0, row=0)`, so the scale signs and the
    # origin corner have to follow the actual array layout rather
    # than assuming a north-up grid. `px` / `py` from `_new_coords`
    # are already signed (positive when the coord ascends along the
    # axis, negative when it descends), and `*_edge_start` is the
    # leading edge of the first row / column on the side of
    # `vals[0]` -- exactly what rasterio wants for `c` and `f`.
    if 'transform' in agg.attrs:
        new_attrs['transform'] = (
            px, 0.0, x_edge_start,
            0.0, py, y_edge_start,
        )
    # NOTE(review): placed at function level (not inside `if has_nodata:`)
    # in this reconstruction, because it follows the transform block in
    # the flattened listing; unlike the identity and 3D paths it would
    # then run even when no sentinel was resolved. Confirm against the
    # original source.
    _refresh_nodata_attrs(agg.attrs, new_attrs)

    # Carry across scalar (zero-dim) non-dim coords like rioxarray's
    # `spatial_ref` or a squeezed `time` / `band` selector. The
    # identity path (scale==1.0) preserves these via `agg.copy()`;
    # the 2D non-identity path must match so chained rioxarray
    # pipelines don't silently lose CRS / spatial_ref / scalar
    # selector coords. Spatially-shaped non-dim coords (dims include
    # ydim or xdim) are not carried because their length changed.
    extra_coords = {}
    for coord_name, coord in agg.coords.items():
        if coord_name in (ydim, xdim):
            continue  # spatial dim-coords are rebuilt above
        if coord.ndim == 0:
            extra_coords[coord_name] = coord

    result = xr.DataArray(
        result_data,
        name=name,
        dims=agg.dims,
        coords={ydim: new_y, xdim: new_x, **extra_coords},
        attrs=new_attrs,
    )
    # Copy per-coordinate attrs (e.g. units) from the input spatial coords.
    for dim in (ydim, xdim):
        if dim in agg.coords and agg[dim].attrs:
            result[dim].attrs = dict(agg[dim].attrs)
    return result