Source code for xrspatial.utils

from __future__ import annotations

from math import ceil
import warnings

import numpy as np
import xarray as xr
from numba import cuda, jit

try:
    import cupy
except ImportError:
    cupy = None


try:
    import dask.array as da
except ImportError:
    da = None


try:
    import dask.dataframe as dd
except ImportError:
    dd = None


# Shared numba decorator: compile in nopython mode with ``nogil=True`` so
# the compiled function can release the GIL while it runs.
ngjit = jit(nopython=True, nogil=True)


# ---------- Boundary mode utilities ----------
# Names of the edge-handling modes recognised by the boundary helpers.
VALID_BOUNDARY_MODES = ('nan', 'nearest', 'reflect', 'wrap')


def _validate_boundary(boundary):
    """Check that *boundary* is one of ``VALID_BOUNDARY_MODES``.

    Parameters
    ----------
    boundary : object
        Candidate boundary mode.

    Raises
    ------
    ValueError
        If *boundary* is not a member of ``VALID_BOUNDARY_MODES``.
    """
    if boundary in VALID_BOUNDARY_MODES:
        return
    raise ValueError(
        f"boundary must be one of {VALID_BOUNDARY_MODES}, "
        f"got {boundary!r}"
    )


def _validate_raster(
    agg,
    *,
    func_name: str,
    name: str = 'raster',
    ndim: int | tuple[int, ...] | None = 2,
    numeric: bool = True,
    integer_only: bool = False,
):
    """Validate that *agg* is an xarray.DataArray with expected properties.

    Parameters
    ----------
    agg : object
        Value to validate.
    func_name : str
        Name of the calling function (for error messages).
    name : str
        Parameter name (for error messages).
    ndim : int, tuple of int, or None
        Allowed number of dimensions.  ``None`` skips the check.
    numeric : bool
        If True, require a real numeric dtype (integer or float).
        Complex dtypes are rejected because xrspatial operations
        assume real-valued raster data.  If False, no dtype check is
        performed at all.
    integer_only : bool
        If True, require an integer dtype specifically.  Only consulted
        when *numeric* is True.

    Raises
    ------
    TypeError
        If *agg* is not an ``xr.DataArray``.
    ValueError
        If the dimensionality or dtype is wrong.
    """
    if not isinstance(agg, xr.DataArray):
        raise TypeError(
            f"{func_name}(): `{name}` must be an xarray.DataArray, "
            f"got {type(agg).__module__}.{type(agg).__qualname__}"
        )

    if ndim is not None:
        allowed = (ndim,) if isinstance(ndim, int) else tuple(ndim)
        if agg.ndim not in allowed:
            # Join with ' or ' (spaces on both sides) so several allowed
            # ranks read "2D or 3D" rather than "2Dor 3D".
            expected = ' or '.join(f'{d}D' for d in allowed)
            raise ValueError(
                f"{func_name}(): `{name}` must be {expected}, "
                f"got {agg.ndim}D"
            )

    if not numeric:
        return

    if integer_only:
        if not np.issubdtype(agg.dtype, np.integer):
            raise ValueError(
                f"{func_name}(): `{name}` must have an integer dtype, "
                f"got {agg.dtype}"
            )
        return

    # np.number also covers complex dtypes, so exclude those explicitly.
    if (
        not np.issubdtype(agg.dtype, np.number)
        or np.issubdtype(agg.dtype, np.complexfloating)
    ):
        raise ValueError(
            f"{func_name}(): `{name}` must have a real numeric "
            f"dtype (integer or float), got {agg.dtype}"
        )


def _validate_matching_shape(
    agg,
    expected_shape,
    *,
    func_name: str,
    name: str = 'raster',
    expected_name: str = 'the primary raster',
):
    """Ensure a companion raster has the same shape as a reference grid.

    Intended for secondary inputs (e.g. start points, pour points, flow
    accumulation) that must cover the same grid as the primary raster
    before any kernel indexes into them.

    Parameters
    ----------
    agg : xarray.DataArray
        Companion raster; only its ``shape`` attribute is read.
    expected_shape : tuple of int
        Shape that ``agg.shape`` must equal.
    func_name : str
        Name of the calling function (for error messages).
    name : str
        Parameter name (for error messages).
    expected_name : str
        Description of the raster that supplied *expected_shape*.

    Raises
    ------
    ValueError
        If ``agg.shape`` differs from *expected_shape*.
    """
    # Normalise both sides to tuples so lists and tuples compare equal.
    actual = tuple(agg.shape)
    wanted = tuple(expected_shape)
    if actual == wanted:
        return
    raise ValueError(
        f"{func_name}(): `{name}` shape {actual} does not "
        f"match {expected_name} shape {wanted}"
    )


def _validate_scalar(
    value,
    *,
    func_name: str,
    name: str,
    dtype: type | tuple = (int, float),
    min_val=None,
    max_val=None,
    min_exclusive: bool = False,
):
    """Check the type and range of a scalar argument.

    Parameters
    ----------
    value : object
        Value to validate.
    func_name : str
        Name of the calling function (for error messages).
    name : str
        Parameter name (for error messages).
    dtype : type or tuple of types
        Allowed Python types (checked with ``isinstance``).  When ``int``
        or ``float`` is listed, the matching numpy scalar family
        (``np.integer`` / ``np.floating``) is accepted as well.
    min_val, max_val : numeric or None
        Bounds; ``None`` disables the corresponding check.  The upper
        bound is always inclusive.
    min_exclusive : bool
        If True, the lower bound is exclusive (``>`` instead of ``>=``).

    Raises
    ------
    TypeError
        If *value* is not an instance of an accepted type.
    ValueError
        If *value* is outside the allowed range.
    """
    requested = dtype if isinstance(dtype, tuple) else (dtype,)

    # Widen the accepted set with numpy scalar equivalents so that
    # e.g. np.int64(5) satisfies dtype=int.
    accepted = tuple(requested)
    if int in requested:
        accepted += (np.integer,)
    if float in requested:
        accepted += (np.floating,)

    if not isinstance(value, accepted):
        if isinstance(dtype, type):
            expected = dtype.__name__
        else:
            expected = ' or '.join(t.__name__ for t in requested)
        raise TypeError(
            f"{func_name}(): `{name}` must be {expected}, "
            f"got {type(value).__name__}"
        )

    if min_val is not None:
        if min_exclusive and value <= min_val:
            raise ValueError(
                f"{func_name}(): `{name}` must be > {min_val}, "
                f"got {value}"
            )
        if not min_exclusive and value < min_val:
            raise ValueError(
                f"{func_name}(): `{name}` must be >= {min_val}, "
                f"got {value}"
            )

    if max_val is not None and value > max_val:
        raise ValueError(
            f"{func_name}(): `{name}` must be <= {max_val}, "
            f"got {value}"
        )


def _validate_mfd_fractions(data, *, func_name: str, name: str = 'fractions',
                            atol: float = 1e-6):
    """Check the value invariants of an ``(8, H, W)`` MFD fraction grid.

    The public MFD functions document that each cell's 8 fraction bands
    lie in ``[0, 1]`` and sum to either 1.0 (flow) or 0.0
    (pit/flat/sink), with all-NaN bands at edges and nodata cells.  This
    helper enforces those invariants up front so bad input fails with a
    clear message before any hydrology math runs.

    Checks, in this order:

    * NaN bands are all-or-nothing per cell: either all 8 directions are
      NaN or none are.
    * No fraction is negative.
    * For cells without NaN, the band sum is 1.0 or 0.0 within *atol*.

    Only in-memory numpy and cupy arrays are inspected.  Any other array
    type (e.g. dask) returns immediately so that validation never
    triggers computation.  The leading axis is assumed to already hold
    the 8 direction bands; callers check the shape first.

    Parameters
    ----------
    data : numpy.ndarray, cupy.ndarray, or dask.array.Array
        The fraction grid (``DataArray.data``).
    func_name : str
        Name of the calling function (for error messages).
    name : str
        Parameter name (for error messages).
    atol : float
        Absolute tolerance for the band-sum check.

    Raises
    ------
    ValueError
        If any cell has a partial-NaN band pattern, a negative fraction,
        or a band sum that is neither ~1.0 nor ~0.0.
    """
    on_gpu = is_cupy_array(data)
    if not on_gpu and not isinstance(data, np.ndarray):
        # Lazy or unrecognised array type: skip value validation.
        return
    xp = cupy if on_gpu else np

    prefix = f"{func_name}(): `{name}`"

    # Number of NaN bands per cell (0..8).
    nan_count = xp.isnan(data).sum(axis=0)
    partially_nan = (nan_count > 0) & (nan_count < 8)
    if bool(partially_nan.any()):
        raise ValueError(
            f"{prefix} has cells with a partial-NaN band pattern.  Each "
            f"cell must have all 8 direction bands NaN (edge/nodata) or "
            f"none of them NaN."
        )

    # NaN < 0 evaluates to False, so NaN cells never trigger this check.
    has_negative = (data < 0).any()
    if bool(has_negative):
        raise ValueError(
            f"{prefix} contains negative flow fractions.  Fractions must "
            f"be in [0, 1]."
        )

    # nansum treats NaN bands as 0; only cells with no NaN are judged.
    sums = xp.nansum(data, axis=0)
    close_to_one = xp.abs(sums - 1.0) <= atol
    close_to_zero = xp.abs(sums) <= atol
    bad_sum = (nan_count == 0) & ~(close_to_one | close_to_zero)
    if bool(bad_sum.any()):
        raise ValueError(
            f"{prefix} has cells whose flow fractions do not sum to 1.0 "
            f"(flow) or 0.0 (pit/flat/sink) within tolerance {atol}."
        )


def _boundary_to_dask(boundary, is_cupy=False):
    """Translate a boundary mode name into the value passed as the
    *boundary* argument of ``dask.array.map_overlap``.

    ``'nan'`` becomes a NaN fill value (taken from cupy when *is_cupy*
    is true, otherwise from numpy).  The other modes map to strings;
    ``'wrap'`` is renamed to ``'periodic'``.  Any other value raises
    ``KeyError``.
    """
    if boundary != 'nan':
        dask_names = {
            'nearest': 'nearest',
            'reflect': 'reflect',
            'wrap': 'periodic',
        }
        return dask_names[boundary]

    if not is_cupy:
        return np.nan

    # Imported lazily so cupy is only required on the GPU path.
    import cupy as _cp
    return _cp.nan


def _pad_array(data, depth, boundary):
    """Pad a 2-D numpy or cupy array using the given *boundary* mode.

    Parameters
    ----------
    data : array-like
        2-D array to pad.
    depth : int or tuple of int
        Number of cells added on each side.  An int is applied to both
        axes; a tuple ``(d0, d1)`` gives one depth per axis.
    boundary : str
        One of ``'nearest'``, ``'reflect'``, ``'wrap'``.  ``'nan'`` is
        not handled here and must be dealt with by the caller; an
        unknown mode raises ``KeyError``.

    Returns
    -------
    Padded array produced by ``cupy.pad`` for cupy input and ``np.pad``
    otherwise.
    """
    # 'reflect' maps to the pad mode 'symmetric' because that variant
    # repeats the edge element, matching dask map_overlap's 'reflect'.
    pad_mode = {
        'nearest': 'edge',
        'reflect': 'symmetric',
        'wrap': 'wrap',
    }[boundary]

    if isinstance(depth, int):
        pad_width = ((depth, depth),) * 2
    else:
        pad_width = tuple((d, d) for d in depth)

    xp = cupy if is_cupy_array(data) else np
    return xp.pad(data, pad_width, mode=pad_mode)


def has_cuda_and_cupy():
    """Return True only when both ``_has_cuda()`` and ``_has_cupy()`` do."""
    return _has_cuda() and _has_cupy()


def _has_cupy():
    """Return True if the module-level ``cupy`` name is not ``None``,
    i.e. the optional import at the top of the file succeeded."""
    return cupy is not None


def is_cupy_array(arr):
    """Return True if cupy is importable and *arr* is a ``cupy.ndarray``.

    The availability check runs first, so this is safe to call when cupy
    is not installed.
    """
    return _has_cupy() and isinstance(arr, cupy.ndarray)


def has_dask_array():
    """Return True if the optional ``dask.array`` import succeeded."""
    return da is not None


def has_dask_dataframe():
    """Return True if the optional ``dask.dataframe`` import succeeded."""
    return dd is not None


def _has_cuda():
    """Report whether numba can see a supported CUDA device.

    Returns True when looking up the current GPU succeeds and False when
    numba raises ``CudaSupportError``.
    """
    try:
        # The attribute access is the probe; its value is not needed.
        cuda.cudadrv.devices.gpus.current
    except cuda.cudadrv.error.CudaSupportError:
        return False
    return True


def cuda_args(shape):
    """
    Derive CUDA launch dimensions for a kernel that parallelizes over
    an array of the given shape.

    Parameters
    ----------
    shape: int or tuple of ints
        The shape of the input array that the kernel will parallelize
        over.  A bare int is treated as a 1-tuple.

    Returns
    -------
    bpg, tpb : tuple
        Tuple of (blocks_per_grid, threads_per_block), each with one
        entry per dimension of `shape`.
    """
    dims = (shape,) if isinstance(shape, int) else shape
    ndim = len(dims)

    device_limit = cuda.get_current_device().MAX_THREADS_PER_BLOCK
    # Only half of the device limit is used, to leave room for the
    # registers; the ndim-th root spreads that budget evenly per axis.
    side = int(ceil(device_limit / 2.0) ** (1.0 / ndim))

    threads = (side,) * ndim
    blocks = tuple(int(ceil(extent / side)) for extent in dims)
    return blocks, threads


def calc_cuda_dims(shape):
    """Return ``(blockspergrid, threadsperblock)`` for a 2-D launch.

    Threads per block are fixed at ``(32, 32)``; the grid size along
    each of the first two axes of `shape` is the extent divided by 32,
    rounded up.
    """
    tile_rows, tile_cols = 32, 32
    # Adding (tile - 1) before floor division rounds the quotient up.
    grid_rows = (shape[0] + (tile_rows - 1)) // tile_rows
    grid_cols = (shape[1] + (tile_cols - 1)) // tile_cols
    return (grid_rows, grid_cols), (tile_rows, tile_cols)


def is_cupy_backed(agg: xr.DataArray):
    """Return True if ``agg.data._meta`` is an object from the cupy package.

    The decision is based on the top-level module name of the ``_meta``
    object's type.  Inputs without ``data`` or ``_meta`` attributes
    yield False.
    """
    try:
        module_name = type(agg.data._meta).__module__
        return module_name.split(".")[0] == "cupy"
    except AttributeError:
        return False


def is_dask_cupy(agg: xr.DataArray):
    """Return True if *agg* is backed by a dask array of cupy chunks.

    Returns False when dask is not installed, when ``agg.data`` is not a
    ``dask.array.Array``, or when ``is_cupy_backed(agg)`` is False.
    """
    # ``da`` is None when the optional dask import failed.  Check that
    # first so this returns False instead of raising AttributeError on
    # ``None.Array``.
    if da is None:
        return False
    return isinstance(agg.data, da.Array) and is_cupy_backed(agg)


def not_implemented_func(agg, *args, messages='Not yet implemented.'):
    """Placeholder backend function that always raises.

    ``agg`` and ``*args`` are accepted and ignored; ``messages`` is used
    as the ``NotImplementedError`` text.
    """
    raise NotImplementedError(messages)


class ArrayTypeFunctionMapping(object):
    """Select a backend-specific function from a DataArray's data type.

    An instance stores one callable per backend (numpy, cupy,
    dask+numpy, dask+cupy).  Calling the instance with a DataArray
    returns the stored callable that matches ``arr.data``; it does not
    invoke that callable.
    """

    def __init__(self, numpy_func, cupy_func, dask_func, dask_cupy_func):
        self.numpy_func = numpy_func
        self.cupy_func = cupy_func
        self.dask_func = dask_func
        self.dask_cupy_func = dask_cupy_func

    def __call__(self, arr):
        """Return the function registered for the backend of *arr*.

        Raises
        ------
        TypeError
            If ``arr.data`` is not a numpy, cupy, or dask array (or the
            library needed to recognise it is unavailable).
        """
        data = arr.data

        # numpy case
        if isinstance(data, np.ndarray):
            return self.numpy_func

        gpu_available = has_cuda_and_cupy()

        # cupy case
        if gpu_available and is_cupy_array(data):
            return self.cupy_func

        # dask cases: confirm dask is importable before touching
        # ``da.Array`` so a missing dask install cannot raise
        # AttributeError here.
        if has_dask_array() and isinstance(data, da.Array):
            # dask + cupy is only selected when CUDA and cupy are usable;
            # otherwise the dask + numpy function is returned.
            if gpu_available and is_cupy_backed(arr):
                return self.dask_cupy_func
            return self.dask_func

        # Report the backing array type: ``type(arr)`` would only ever
        # name the DataArray wrapper, which is not what was rejected.
        raise TypeError("Unsupported Array Type: {}".format(type(data)))


def _classify_backend(arr):
    """Name the storage backend behind a DataArray.

    Returns one of ``"numpy"``, ``"cupy"``, ``"dask+numpy"`` or
    ``"dask+cupy"`` (the same four cases ``ArrayTypeFunctionMapping``
    dispatches on), or ``"unknown"`` for anything else so the caller can
    surface a clear error.
    """
    data = arr.data

    # Dask is tested first; its chunks decide between the two dask labels.
    wrapped_in_dask = has_dask_array() and isinstance(data, da.Array)
    if wrapped_in_dask:
        return "dask+cupy" if is_cupy_backed(arr) else "dask+numpy"

    if is_cupy_array(data):
        return "cupy"

    return "numpy" if isinstance(data, np.ndarray) else "unknown"


def validate_arrays(*arrays):
    """Check that several DataArrays are mutually compatible.

    Every array after the first must have the same ``data.shape`` and
    the same backend (as reported by ``_classify_backend``) as the
    first.  When the first array is dask-backed, any later array whose
    chunks differ has its ``data`` replaced in place by a version
    rechunked to match the first.

    Raises
    ------
    ValueError
        If fewer than two arrays are given, or if shapes or backends
        differ.
    """
    if len(arrays) < 2:
        raise ValueError(
            "validate_arrays() input must contain 2 or more arrays"
        )

    reference, others = arrays[0], arrays[1:]
    reference_backend = _classify_backend(reference)

    for position, other in enumerate(others, start=1):
        if reference.data.shape != other.data.shape:
            raise ValueError("input arrays must have equal shapes")

        backend = _classify_backend(other)
        if backend != reference_backend:
            raise ValueError(
                "input arrays must share the same backend; got "
                "'{}' (array 0) and '{}' (array {})".format(
                    reference_backend, backend, position
                )
            )

    if not (has_dask_array() and isinstance(reference.data, da.Array)):
        return

    # ensure dask chunksizes of all arrays are the same
    for other in others:
        if reference.chunks != other.chunks:
            other.data = other.data.rechunk(reference.chunks)


def get_xy_range(raster, xdim=None, ydim=None):
    """
    Compute the coordinate extent of `raster` along x and y.

    Parameters
    ----------
    raster: xarray.DataArray
        Input raster.
    xdim: str, default = None
        Name of the x coordinate dimension in input `raster`.
        If not provided, assume xdim is `raster.dims[-1]`
    ydim: str, default = None
        Name of the y coordinate dimension in input `raster`
        If not provided, assume ydim is `raster.dims[-2]`

    Returns
    -------
    xrange, yrange
        Tuple of tuples: (x-range, y-range).
        xrange: tuple of (xmin, xmax)
        yrange: tuple of (ymin, ymax)
    """
    ydim = raster.dims[-2] if ydim is None else ydim
    xdim = raster.dims[-1] if xdim is None else xdim

    # .item() turns each reduction result into a plain Python scalar.
    x_coord = raster[xdim]
    xrange = (x_coord.min().item(), x_coord.max().item())

    y_coord = raster[ydim]
    yrange = (y_coord.min().item(), y_coord.max().item())

    return xrange, yrange


def calc_res(raster, xdim=None, ydim=None):
    """
    Calculate the resolution of xarray.DataArray raster and return it
    as the two-tuple (xres, yres).

    Each resolution is the coordinate span along that axis divided by
    the number of cells minus one, i.e. the average step between
    neighbouring coordinates.  A warning is emitted (through
    `_warn_if_irregular_spacing`) when an axis is not evenly spaced.

    Parameters
    ----------
    raster: xr.DataArray
        Input raster.
    xdim: str, default = None
        Name of the x coordinate dimension in input `raster`.
        If not provided, assume xdim is `raster.dims[-1]`
    ydim: str, default = None
        Name of the y coordinate dimension in input `raster`
        If not provided, assume ydim is `raster.dims[-2]`

    Returns
    -------
    xres, yres: tuple
        Tuple of (x-resolution, y-resolution).
    """
    ydim = raster.dims[-2] if ydim is None else ydim
    xdim = raster.dims[-1] if xdim is None else xdim

    n_rows, n_cols = raster.shape[-2:]
    (x_low, x_high), (y_low, y_high) = get_xy_range(raster, xdim, ydim)

    # NOTE: an axis with a single cell makes the denominator zero.
    xres = (x_high - x_low) / (n_cols - 1)
    yres = (y_high - y_low) / (n_rows - 1)

    for dim, res, label in ((xdim, xres, "x"), (ydim, yres, "y")):
        _warn_if_irregular_spacing(raster, dim, res, label)

    return xres, yres


def _warn_if_irregular_spacing(raster, dim, res, axis_label):
    """Emit a ``UserWarning`` if the 1-D coordinate on `dim` is uneven.

    `calc_res` collapses a coordinate to one averaged cell size (full
    span divided by ``n - 1``).  On an irregular grid that average
    misrepresents every cell and the caller would otherwise get no
    signal.  The warning makes the averaging visible and points at
    ``attrs['res']`` as an explicit override (which
    ``get_dataarray_resolution`` honors before it reaches `calc_res`).

    Nothing happens when the coordinate is missing, is not 1-D, has
    fewer than 3 points, is non-numeric, or contains non-finite steps.
    """
    coord = raster.coords.get(dim, None)
    if coord is None:
        return
    # With only two points the single step always equals the averaged
    # step, so irregularity can only be detected from 3 points up.
    if coord.ndim != 1 or coord.size < 3:
        return

    values = np.asarray(coord.values)
    if not np.issubdtype(values.dtype, np.number):
        return

    steps = np.diff(values)
    if not np.all(np.isfinite(steps)):
        return

    # `res` is derived from a min/max span and is therefore not negative,
    # while a descending (north-up) axis has negative steps yet is still
    # regular, so magnitudes are compared.  The relative tolerance stops
    # floating-point jitter from triggering the warning.
    evenly_spaced = np.allclose(np.abs(steps), abs(res), rtol=1e-5, atol=0)
    if not evenly_spaced:
        warnings.warn(
            f"xrspatial: '{dim}' coordinate is not evenly spaced; "
            f"using an averaged {axis_label}-resolution of {res}. "
            "Per-cell spacing varies, so distance-based results may be "
            "inaccurate. Set attrs['res'] to an explicit resolution to "
            "silence this warning.",
            UserWarning,
            stacklevel=3,
        )


def get_dataarray_resolution(
    agg: xr.DataArray,
    xdim: str | None = None,
    ydim: str | None = None,
):
    """
    Calculate resolution of xarray.DataArray.

    An explicit ``agg.attrs['res']`` takes precedence: either a single
    number (used for both axes) or a 2-element tuple, list or
    ``numpy.ndarray`` of numbers ``(x, y)``.  Python ``int``/``float``
    and numpy integer/floating scalars are all accepted, so a value
    such as ``np.array([30, 30])`` is honored.  Anything else, including
    a missing attribute, falls back to `calc_res`.

    Parameters
    ----------
    agg: xarray.DataArray
        Input raster.
    xdim: str, default = None
        Name of the x coordinate dimension in input `raster`.
        If not provided, assume xdim is `raster.dims[-1]`
    ydim: str, default = None
        Name of the y coordinate dimension in input `raster`
        If not provided, assume ydim is `raster.dims[-2]`

    Returns
    -------
    cellsize_x, cellsize_y: tuple
        Tuple of (x cell size, y cell size).
    """
    # numpy scalars such as np.int64 / np.float32 are not instances of
    # the Python int / float types, so they are listed explicitly.
    number_types = (int, float, np.integer, np.floating)

    # Only the attribute parsing is guarded.  `calc_res` is called
    # outside the try block so a failure there propagates once instead
    # of being caught and retried.
    try:
        cellsize = agg.attrs.get("res")
        if isinstance(cellsize, number_types):
            return cellsize, cellsize
        if (
            isinstance(cellsize, (tuple, np.ndarray, list))
            and len(cellsize) == 2
            and isinstance(cellsize[0], number_types)
            and isinstance(cellsize[1], number_types)
        ):
            cellsize_x, cellsize_y = cellsize
            return cellsize_x, cellsize_y
    except Exception:
        # Deliberate best effort: an unusable 'res' attribute (e.g. a
        # 0-d array, which has no len()) means "derive from coordinates".
        pass

    return calc_res(agg, xdim, ydim)


def lnglat_to_meters(longitude, latitude):
    """
    Project (longitude, latitude) values into Web Mercator coordinates
    (meters East of Greenwich and meters North of the Equator).

    Longitude and latitude can be provided as scalars, Pandas columns,
    or Numpy arrays, and will be returned in the same form.  Lists
    or tuples will be converted to Numpy arrays.

    Parameters
    ----------
    longitude: float
        Input longitude.
    latitude: float
        Input latitude.

    Returns
    -------
    easting, northing : tuple
        Tuple of (easting, northing).

    Examples
    --------
    .. sourcecode:: python

        >>> easting, northing = lnglat_to_meters(-74, 40.71)
        >>> easting, northing = lnglat_to_meters(np.array([-74]),
        >>>                                      np.array([40.71]))
        >>> df = pandas.DataFrame(dict(longitude=np.array([-74]),
        >>>                            latitude=np.array([40.71])))
        >>> df.loc[:, 'longitude'], df.loc[:, 'latitude'] = lnglat_to_meters(
        >>>     df.longitude, df.latitude)
    """
    # Plain sequences have no elementwise arithmetic; promote them.
    longitude, latitude = (
        np.array(value) if isinstance(value, (list, tuple)) else value
        for value in (longitude, latitude)
    )

    origin_shift = np.pi * 6378137
    easting = longitude * origin_shift / 180.0

    mercator_angle = (90 + latitude) * np.pi / 360.0
    northing = np.log(np.tan(mercator_angle)) * origin_shift / np.pi

    return (easting, northing)


def height_implied_by_aspect_ratio(W, X, Y):
    """
    Compute the height (in pixels) implied by a width, an x-range and a
    y-range, using simple ratios so the aspect ratio is maintained.

    Parameters
    ----------
    W: int
        Width in pixel.
    X: tuple
        X-range in data units.
    Y: tuple
        Y-range in data units.

    Returns
    -------
    height : int
        height in pixels (the ratio is truncated by ``int``).

    Examples
    --------
    .. sourcecode:: python

        >>> plot_width = 1000
        >>> x_range = (0, 35)
        >>> y_range = (0, 70)
        >>> plot_height = height_implied_by_aspect_ratio(
                plot_width,
                x_range,
                y_range,
            )
    """
    y_span = Y[1] - Y[0]
    x_span = X[1] - X[0]
    return int((W * y_span) / x_span)


def bands_to_img(r, g, b, nodata=1):
    """Combine three 2-D bands into an RGBA ``PIL.Image``.

    Each band is cast to ``uint8`` and written to its colour channel.
    The alpha channel is 0 (transparent) where `r` is NaN or
    ``r <= nodata`` and 255 (opaque) elsewhere; only the red band is
    consulted for transparency.
    """
    from PIL import Image

    height, width = r.shape
    rgba = np.zeros((height, width, 4), dtype=np.uint8)
    for channel, band in enumerate((r, g, b)):
        rgba[:, :, channel] = band.astype(np.uint8)

    transparent = np.logical_or(np.isnan(r), r <= nodata)
    rgba[:, :, 3] = np.where(transparent, 0, 255).astype(np.uint8)
    return Image.fromarray(rgba, "RGBA")


def canvas_like(
    raster,
    width=512,
    height=None,
    x_range=None,
    y_range=None,
    **kwargs
):

    """
    Resample a xarray.DataArray by canvas width and bounds.
    Height of the resampled raster is implied from the canvas width
    using aspect ratio of original raster.

    This function uses datashader.Canvas.raster internally.
    Most of the docstrings are copied from Datashader.

    Handles 2D or 3D xarray.DataArray, assuming that the last two
    array dimensions are the y-axis and x-axis that are to be
    resampled. If a 3D array is supplied, a layer may be specified
    to select the layer along the first dimension to resample.

    Parameters
    ----------
    raster : xarray.DataArray
        2D or 3D labeled data array.  When a range is not supplied, its
        ``"x"`` / ``"y"`` coordinates provide the bounds.
    width : int, default=512
        Width of the output aggregate in pixels.
    height : int, default=None
        Height of the output aggregate in pixels.
        If not provided, height will be implied from `width`
        using aspect ratio of input raster.
    x_range : tuple of int, optional
        A tuple representing the bounds inclusive space ``[min, max]``
        along the x-axis.
    y_range : tuple of int, optional
        A tuple representing the bounds inclusive space ``[min, max]``
        along the y-axis.
    **kwargs
        Forwarded unchanged to ``datashader.Canvas.raster`` (for
        example ``layer``, the value along the z dimension of a 3D
        array).

    Raises
    ------
    ImportError
        If datashader is not installed.

    References
    ----------
        - https://datashader.org/_modules/datashader/core.html#Canvas
    """

    # Fall back to the raster's own coordinate extent for missing bounds.
    if x_range is None:
        x_coord = raster.coords["x"]
        x_range = (x_coord.min().item(), x_coord.max().item())
    if y_range is None:
        y_coord = raster.coords["y"]
        y_range = (y_coord.min().item(), y_coord.max().item())

    if height is None:
        height = height_implied_by_aspect_ratio(width, x_range, y_range)

    # datashader is an optional dependency, imported only when needed.
    try:
        import datashader as ds
    except ImportError:
        raise ImportError(
            "canvas_like requires datashader: pip install datashader"
        )

    canvas = ds.Canvas(
        plot_width=width,
        plot_height=height,
        x_range=x_range,
        y_range=y_range,
    )
    return canvas.raster(raster, **kwargs)


def _hex_to_rgb(c):
    """Convert a hex color string to an ``(r, g, b)`` tuple of ints.

    Accepts the 6-digit form (``'#ff0000'`` or ``'ff0000'``) and the
    3-digit shorthand (``'#f00'``), in which each digit is doubled, so
    ``'#f00'`` equals ``'#ff0000'``.  Leading ``'#'`` characters are
    optional.  For longer strings only the first six digits are read.

    Raises
    ------
    ValueError
        If the string does not provide three parsable hex pairs.
    """
    digits = c.lstrip('#')
    if len(digits) == 3:
        # Expand shorthand: 'abc' -> 'aabbcc'.
        digits = ''.join(ch * 2 for ch in digits)
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def color_values(agg, color_key, alpha=255):
    """Render a categorical raster as an RGBA ``PIL.Image``.

    Parameters
    ----------
    agg : object with a ``data`` array
        Raster whose cell values are looked up in `color_key`.
    color_key : dict
        Maps cell values to hex color strings.
    alpha : int, default=255
        Alpha value given to every mapped cell.

    Cells whose value is not a key of `color_key` get the packed value
    0, i.e. all four channel bytes are zero.
    """
    from PIL import Image

    def _pack_rgba(hex_color):
        # Reinterpret the four uint8 channel bytes as one uint32 so each
        # cell can carry its whole color as a single number.
        red, green, blue = _hex_to_rgb(hex_color)
        channels = np.array([red, green, blue, alpha]).astype(np.uint8)
        return channels.view(np.uint32)[0]

    packed_by_value = {}
    for value, hex_color in color_key.items():
        packed_by_value[value] = _pack_rgba(hex_color)

    lookup = np.vectorize(lambda v: packed_by_value.get(v, 0))
    packed = lookup(agg.data).astype(np.uint32)

    # Unpack each uint32 back into its 4 channel bytes: shape (..., 4).
    rgba = packed.view(np.uint8).reshape(agg.data.shape + (4,))
    return Image.fromarray(rgba, "RGBA")


def _infer_coord_unit_type(coord: xr.DataArray, cellsize: float) -> str:
    """
    Guess whether a spatial coordinate axis is expressed in degrees or
    in a linear unit.

    An explicit ``units`` attribute is trusted first.  Otherwise a
    conservative numeric heuristic on the value range and cell spacing
    is used.

    Parameters
    ----------
    coord : xr.DataArray
        1D coordinate variable (x or y).
    cellsize : float
        Mean spacing along this coordinate.

    Returns
    -------
    str
        ``'degrees'``, ``'linear'`` (meters/feet/etc) or ``'unknown'``.
    """
    linear_units = {
        "m", "meter", "metre", "meters", "metres",
        "km", "kilometer", "kilometre", "kilometers", "kilometres",
        "ft", "foot", "feet",
    }
    units = str(coord.attrs.get("units", "")).lower()

    # 1) Explicit units, if present
    if "degree" in units or units == "deg":
        return "degrees"
    if units in linear_units:
        return "linear"

    # 2) Numeric heuristics (very conservative)
    vals = coord.values
    if vals.size < 2 or not np.issubdtype(vals.dtype, np.number):
        return "unknown"

    low = float(np.nanmin(vals))
    high = float(np.nanmax(vals))
    span = abs(high - low)
    step = abs(float(cellsize))

    # Typical geographic axes: values within +/-360 and a spacing of
    # roughly 1e-5 to 0.5 degrees.
    within_360 = abs(low) <= 360.0 and abs(high) <= 360.0
    if within_360 and 1e-5 <= step <= 0.5:
        return "degrees"

    # Typical projected axes in meters (e.g. UTM / national grids):
    # large span and a spacing of at least ~0.1.
    if span > 1000.0 and step >= 0.1:
        return "linear"

    return "unknown"


def _infer_vertical_unit_type(agg):
    """Guess whether the raster values are angles or elevations.

    The ``units`` attribute is checked first.  If it is inconclusive, a
    min/max sample obtained from ``_sample_windows_min_max`` is
    classified by range.  Any failure while sampling yields
    ``"unknown"``.

    Returns
    -------
    str
        ``"angle"``, ``"elevation"`` or ``"unknown"``.
    """
    elevation_units = {
        "m", "meter", "metre", "meters", "metres",
        "km", "kilometer", "kilometre", "kilometers", "kilometres",
        "ft", "foot", "feet",
    }
    units = str(agg.attrs.get("units", "")).lower()

    # Cheap / reliable first: substring match for angles, exact match
    # for linear units.
    if "deg" in units or "rad" in units:
        return "angle"
    if units in elevation_units:
        return "elevation"

    # Numeric fallback: sample only (never full compute)
    data = agg.data
    try:
        vmin, vmax = _sample_windows_min_max(
            data, max_window_elems=65536, windows=5
        )
    except Exception:
        return "unknown"

    if not (np.isfinite(vmin) and np.isfinite(vmax)):
        return "unknown"

    span = vmax - vmin

    # Elevation-ish heuristic
    looks_like_elevation = 10.0 <= span <= 20000.0 and vmin > -500.0
    if looks_like_elevation:
        return "elevation"

    # Angle-ish heuristic
    in_angle_range = -360.0 <= vmin <= 360.0 and -360.0 <= vmax <= 360.0
    if in_angle_range and span <= 720.0:
        return "angle"

    return "unknown"


def warn_if_unit_mismatch(agg: xr.DataArray) -> None:
    """
    Heuristically detect a horizontal vs vertical unit mismatch.

    Targets the common case of:
    - coordinates in degrees (lon/lat)
    - elevation values in meters/feet

    Emits a UserWarning if a likely mismatch is detected.  Returns
    silently when the resolution cannot be computed, when the array has
    fewer than two dims, or when either of the last two dims has no
    coordinate.
    """
    try:
        cellsize_x, cellsize_y = get_dataarray_resolution(agg)
    except Exception:
        # Without a resolution there is nothing to base a guess on.
        return

    if len(agg.dims) < 2:
        return

    # The last two dims are taken as (y, x); typical names are
    # ('y', 'x') or ('lat', 'lon').
    dim_y, dim_x = agg.dims[-2:]
    coord_x = agg.coords.get(dim_x, None)
    coord_y = agg.coords.get(dim_y, None)
    if coord_x is None or coord_y is None:
        # Can't infer spatial types without coords
        return

    horizontal = {
        _infer_coord_unit_type(coord_x, cellsize_x),
        _infer_coord_unit_type(coord_y, cellsize_y),
    }
    vertical = _infer_vertical_unit_type(agg)

    # Warn only when at least one axis looks like degrees AND the values
    # look like elevations: almost certainly "lat/lon degrees + meter
    # elevations".  "unknown" on either side never matches.
    if vertical == "elevation" and "degrees" in horizontal:
        warnings.warn(
            "xrspatial: input DataArray appears to have coordinates in degrees "
            "but elevation values in a linear unit (e.g. meters/feet). "
            "Slope/aspect operations expect horizontal distances in the same "
            "units as vertical. Consider reprojecting to a projected CRS with "
            "meter-based coordinates before calling `slope`.",
            UserWarning,
        )


# ---------- Z-unit conversion for geodesic methods ----------
# Maps a unit name (singular, plural and abbreviated spellings) to the
# number of meters in one of that unit, e.g. 1 foot = 0.3048 m.
Z_UNITS = {
    'meter': 1.0, 'meters': 1.0, 'm': 1.0,
    'foot': 0.3048, 'feet': 0.3048, 'ft': 0.3048,
    'kilometer': 1000.0, 'kilometers': 1000.0, 'km': 1000.0,
    'mile': 1609.344, 'miles': 1609.344, 'mi': 1609.344,
}


# ---------- Lat/lon coordinate extraction for geodesic methods ----------
# Known dimension / coordinate names (lower-cased for matching).  Besides
# the geographic spellings these sets also contain the generic 'y' / 'x'.
_LAT_NAMES = {'lat', 'latitude', 'y'}
_LON_NAMES = {'lon', 'longitude', 'x'}
# Names that unambiguously mean geographic lat/lon. These take precedence
# over a numeric dimension coord so a curvilinear raster with numeric y/x
# index coords plus real lat/lon coords resolves to the lat/lon coords.
_EXPLICIT_LAT_NAMES = {'lat', 'latitude'}
_EXPLICIT_LON_NAMES = {'lon', 'longitude'}


def _extract_latlon_coords(agg: xr.DataArray):
    """
    Return per-cell latitude and longitude arrays for *agg*.

    The last two dimensions of *agg* are treated as ``(y, x)``. The
    latitude and longitude coordinates are located with ``_find_coord``
    and converted to float64. Two layouts are accepted:

    - both coordinates 1-D (regular geographic grid): each is broadcast
      to ``(agg.sizes[y], agg.sizes[x])``;
    - both coordinates 2-D (curvilinear grid): used unchanged.

    Parameters
    ----------
    agg : xr.DataArray
        Raster with at least two dimensions.

    Returns
    -------
    lat_2d, lon_2d : numpy.ndarray
        2-D float64 arrays. In the 1-D case these are views produced by
        ``np.broadcast_to``.

    Raises
    ------
    ValueError
        If *agg* has fewer than two dimensions, a numeric lat/lon
        coordinate cannot be found, the two coordinates are not both 1-D
        or both 2-D, or the values fail ``_validate_geographic_range``.
    """
    if agg.ndim < 2:
        raise ValueError(
            "geodesic method requires a 2-D DataArray, "
            f"got {agg.ndim}-D"
        )

    y_dim, x_dim = agg.dims[-2:]

    # Look up both coordinates first so a missing one is reported before
    # any value conversion happens.
    lat_source = _find_coord(
        agg, y_dim, _LAT_NAMES, _EXPLICIT_LAT_NAMES, 'latitude')
    lon_source = _find_coord(
        agg, x_dim, _LON_NAMES, _EXPLICIT_LON_NAMES, 'longitude')

    lat_vals = np.asarray(lat_source.values, dtype=np.float64)
    lon_vals = np.asarray(lon_source.values, dtype=np.float64)

    layout = (lat_vals.ndim, lon_vals.ndim)
    if layout == (1, 1):
        # Regular grid: latitude varies down the rows, longitude across
        # the columns.
        grid_shape = (agg.sizes[y_dim], agg.sizes[x_dim])
        lat_2d = np.broadcast_to(lat_vals.reshape(-1, 1), grid_shape)
        lon_2d = np.broadcast_to(lon_vals.reshape(1, -1), grid_shape)
    elif layout == (2, 2):
        # Curvilinear grid: coordinates are already per-cell.
        lat_2d, lon_2d = lat_vals, lon_vals
    else:
        raise ValueError(
            f"lat/lon coordinates must be both 1-D or both 2-D, "
            f"got lat={lat_vals.ndim}-D and lon={lon_vals.ndim}-D"
        )

    _validate_geographic_range(lat_2d, lon_2d)
    return lat_2d, lon_2d


def _find_coord(agg, dim_name, known_names, explicit_names, label):
    """Find a coordinate matching *dim_name* or one of *known_names*.

    A coordinate whose name is unambiguously geographic (*explicit_names*,
    e.g. ``lat``/``longitude``) is preferred over the dimension coord. This
    keeps a curvilinear raster with numeric ``y``/``x`` index coords plus
    real lat/lon coords from silently using the pixel indices as lat/lon.
    If several explicit names are present (e.g. both ``lat`` and
    ``latitude``), the first one in coord order wins.
    """
    # 1) Prefer an explicitly named geographic coordinate (lat/lon).
    for name in agg.coords:
        if str(name).lower() in explicit_names:
            coord = agg.coords[name]
            if np.issubdtype(coord.dtype, np.number):
                return coord

    # 2) Fall back to the dimension name directly.
    if dim_name in agg.coords:
        coord = agg.coords[dim_name]
        if np.issubdtype(coord.dtype, np.number):
            return coord

    # 3) Scan all coords for any other known name (e.g. y/x).
    for name in agg.coords:
        if str(name).lower() in known_names:
            coord = agg.coords[name]
            if np.issubdtype(coord.dtype, np.number):
                return coord

    raise ValueError(
        f"geodesic method requires {label} coordinates on the DataArray. "
        f"No numeric coordinate found for dim '{dim_name}' or any of "
        f"{sorted(known_names)}."
    )


def _validate_geographic_range(lat_2d, lon_2d):
    """Raise ValueError if lat/lon values look non-geographic."""
    lat_min = np.nanmin(lat_2d)
    lat_max = np.nanmax(lat_2d)
    lon_min = np.nanmin(lon_2d)
    lon_max = np.nanmax(lon_2d)

    if lat_min < -90 or lat_max > 90:
        raise ValueError(
            f"Latitude values must be in [-90, 90], "
            f"got [{lat_min}, {lat_max}]. "
            f"Are your coordinates in a projected CRS?"
        )
    if lon_min < -180 or lon_max > 360:
        raise ValueError(
            f"Longitude values must be in [-180, 360], "
            f"got [{lon_min}, {lon_max}]. "
            f"Are your coordinates in a projected CRS?"
        )
    lat_span = lat_max - lat_min
    lon_span = lon_max - lon_min
    if lat_span > 180 or lon_span > 360:
        raise ValueError(
            f"Coordinate span too large for geographic coordinates "
            f"(lat span={lat_span}, lon span={lon_span}). "
            f"Are your coordinates in a projected CRS?"
        )


def _to_float_scalar(x) -> float:
    """Return *x* as a built-in ``float``.

    Handles, in order:

    - ``cupy.ndarray`` (when cupy is importable): converted with
      ``cupy.asnumpy`` and unwrapped with ``.item()``;
    - any object exposing ``.item()`` (numpy scalars, 0-d numpy arrays,
      cupy scalar types): unwrapped with ``.item()``;
    - everything else: passed straight to ``float()``.
    """
    if cupy is not None and isinstance(x, cupy.ndarray):
        host_value = cupy.asnumpy(x)
        return float(host_value.item())

    # cupy scalar types take the same ``.item()`` route as numpy scalars,
    # so they need no separate branch.
    if hasattr(x, "item"):
        return float(x.item())

    return float(x)


def _sample_windows_min_max(
    data,
    *,
    max_window_elems: int = 65536,   # e.g. 256x256
    windows: int = 5,                # corners + center default
) -> tuple[float, float]:
    """
    Estimate ``(nanmin, nanmax)`` of *data* from a few square windows.

    Only the sampled windows are reduced, never the full array. Branches
    exist for dask arrays, cupy arrays and everything numpy can reduce.

    Parameters
    ----------
    data : array-like
        Array with at least two dimensions; sampling happens on the last
        two. For arrays with three or more dimensions only index 0 of
        every leading dimension is sampled.
    max_window_elems : int
        Upper bound on elements per window; the window side is
        ``int(sqrt(max_window_elems))`` clipped to the array's y/x sizes
        (and to at least 1).
    windows : int
        The four corner windows are always sampled. ``windows >= 5``
        adds the centre window; every unit above 5 adds one more window
        along the top-left -> bottom-right diagonal.

    Returns
    -------
    (float, float)
        Sampled minimum and maximum, or ``(nan, nan)`` when either of
        the last two dimensions has size 0.
    """
    # Index 0 on every leading axis when the input is 3-D or higher.
    if getattr(data, "ndim", 0) >= 3:
        lead = (0,) * (data.ndim - 2)
    else:
        lead = ()

    n_rows, n_cols = data.shape[-2], data.shape[-1]
    if n_rows == 0 or n_cols == 0:
        return np.nan, np.nan

    # Square window side, bounded by the array extent.
    side = max(1, min(int(np.sqrt(max_window_elems)), n_rows, n_cols))
    last_row = max(0, n_rows - side)
    last_col = max(0, n_cols - side)

    # Top-left corners of the windows: four corners first.
    starts = [
        (0, 0),
        (0, last_col),
        (last_row, 0),
        (last_row, last_col),
    ]
    if windows >= 5:
        starts.append((
            max(0, n_rows // 2 - side // 2),
            max(0, n_cols // 2 - side // 2),
        ))
    if windows > 5:
        # Interior points of an evenly spaced diagonal (endpoints dropped,
        # they coincide with corners already sampled).
        n_extra = windows - 5
        row_steps = np.linspace(0, last_row, n_extra + 2, dtype=int)[1:-1]
        col_steps = np.linspace(0, last_col, n_extra + 2, dtype=int)[1:-1]
        starts.extend((int(r), int(c)) for r, c in zip(row_steps, col_steps))

    lows = []
    highs = []
    for row0, col0 in starts:
        window = data[lead + (slice(row0, row0 + side),
                              slice(col0, col0 + side))]

        # Pick the reduction pair that matches the window's backend.
        if da is not None and isinstance(window, da.Array):
            reduce_lo, reduce_hi = da.nanmin, da.nanmax
        elif cupy is not None and isinstance(window, cupy.ndarray):
            reduce_lo, reduce_hi = cupy.nanmin, cupy.nanmax
        else:
            reduce_lo, reduce_hi = np.nanmin, np.nanmax

        lows.append(reduce_lo(window))
        highs.append(reduce_hi(window))

    # dask: the per-window results are lazy; combine and compute them now.
    if da is not None and any(isinstance(v, da.Array) for v in lows):
        lo = da.nanmin(da.stack(lows)).compute()
        hi = da.nanmax(da.stack(highs)).compute()
        return _to_float_scalar(lo), _to_float_scalar(hi)

    # cupy: fold the per-window results with cupy before converting.
    if cupy is not None and any(
        isinstance(v, cupy.ndarray)
        or getattr(v.__class__, "__module__", "").startswith("cupy")
        for v in lows
    ):
        lo = lows[0]
        for v in lows[1:]:
            lo = cupy.minimum(lo, v)
        hi = highs[0]
        for v in highs[1:]:
            hi = cupy.maximum(hi, v)
        return _to_float_scalar(lo), _to_float_scalar(hi)

    # numpy scalars
    overall_lo = np.nanmin(np.array(lows, dtype=float))
    overall_hi = np.nanmax(np.array(highs, dtype=float))
    return float(overall_lo), float(overall_hi)


def _no_shuffle_chunks(chunks, dtype, dims, target_mb):
    """Compute target chunk dict that is an exact multiple of *chunks*.

    Returns a ``{dim: size}`` dict, or ``None`` when the current
    chunks already meet or exceed the target.
    """
    base = tuple(c[0] for c in chunks)

    current_bytes = dtype.itemsize
    for b in base:
        current_bytes *= b

    target_bytes = target_mb * 1024 * 1024

    if current_bytes >= target_bytes:
        return None

    ndim = len(base)
    ratio = target_bytes / current_bytes
    multiplier = max(1, int(ratio ** (1.0 / ndim)))

    if multiplier <= 1:
        return None

    return {dim: b * multiplier for dim, b in zip(dims, base)}


def rechunk_no_shuffle(agg, target_mb=128):
    """Rechunk a dask-backed DataArray or Dataset without triggering a shuffle.

    The new chunk sizes are integer multiples of the existing ones, so
    dask can merge whole source chunks instead of splitting and
    recombining partial blocks (which is effectively a shuffle).

    Parameters
    ----------
    agg : xr.DataArray or xr.Dataset
        Input raster(s). If not backed by a dask array the input is
        returned unchanged. A Dataset is delegated to
        ``_rechunk_dataset_no_shuffle``.
    target_mb : int or float
        Target chunk size in megabytes. The chunk layout is computed by
        ``_no_shuffle_chunks`` as a multiple of the source chunk that
        aims not to exceed this target. Default 128.

    Returns
    -------
    xr.DataArray or xr.Dataset
        The rechunked object, or *agg* itself when no rechunking applies.

    Raises
    ------
    TypeError
        If *agg* is not an ``xr.DataArray`` or ``xr.Dataset``.
    ValueError
        If *target_mb* is not positive.

    Examples
    --------
    >>> import dask.array as da
    >>> import xarray as xr
    >>> arr = xr.DataArray(da.zeros((4096, 4096), chunks=256))
    >>> big = rechunk_no_shuffle(arr, target_mb=64)
    >>> big.chunks  # multiples of 256
    """
    # Datasets have their own code path (including target_mb validation).
    if isinstance(agg, xr.Dataset):
        return _rechunk_dataset_no_shuffle(agg, target_mb)

    if not isinstance(agg, xr.DataArray):
        raise TypeError(
            f"rechunk_no_shuffle(): expected xr.DataArray or xr.Dataset, "
            f"got {type(agg).__name__}"
        )

    if target_mb <= 0:
        raise ValueError(
            f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}"
        )

    dask_backed = has_dask_array() and isinstance(agg.data, da.Array)
    if not dask_backed:
        return agg

    target_chunks = _no_shuffle_chunks(
        agg.chunks, agg.dtype, agg.dims, target_mb,
    )
    return agg if target_chunks is None else agg.chunk(target_chunks)
def _rechunk_dataset_no_shuffle(ds, target_mb): """Rechunk every variable in a Dataset without triggering a shuffle.""" if target_mb <= 0: raise ValueError( f"rechunk_no_shuffle(): target_mb must be > 0, got {target_mb}" ) if not has_dask_array(): return ds # Compute target chunks from the first dask-backed variable. # This assumes all variables share the same chunk layout and dtype; # for mixed-dtype Datasets the budget may overshoot on smaller types. new_chunks = None for var in ds.data_vars.values(): if isinstance(var.data, da.Array): new_chunks = _no_shuffle_chunks( var.chunks, var.dtype, var.dims, target_mb, ) break if new_chunks is None: return ds return ds.chunk(new_chunks) def _normalize_depth(depth, ndim): """Normalize depth to a dict {axis: int}. Accepts int, tuple, or dict. Validates completeness and non-negativity. """ if isinstance(depth, dict): expected = set(range(ndim)) got = set(depth.keys()) missing = expected - got extra = got - expected if missing: raise ValueError( f"_normalize_depth: missing axes {sorted(missing)} " f"for ndim={ndim}" ) if extra: raise ValueError( f"_normalize_depth: extra axes {sorted(extra)} " f"for ndim={ndim}" ) for v in depth.values(): if v < 0: raise ValueError( f"_normalize_depth: depth must be non-negative, got {v}" ) return dict(depth) if isinstance(depth, int): if depth < 0: raise ValueError( f"_normalize_depth: depth must be non-negative, got {depth}" ) return {ax: depth for ax in range(ndim)} if isinstance(depth, tuple): if len(depth) != ndim: raise ValueError( f"_normalize_depth: tuple length {len(depth)} != ndim {ndim}" ) for v in depth: if v < 0: raise ValueError( f"_normalize_depth: depth must be non-negative, got {v}" ) return {ax: d for ax, d in enumerate(depth)} raise TypeError( f"_normalize_depth: expected int, tuple, or dict, got {type(depth).__name__}" ) def _pad_nan(data, depth): """Pad a 2-D numpy or cupy array with NaN on each side. 
Parameters ---------- data : numpy or cupy array depth : tuple of int ``(d0, d1)`` cells to pad per axis. """ pad_width = tuple((d, d) for d in depth) if is_cupy_array(data): if np.issubdtype(data.dtype, np.integer): data = data.astype(cupy.float64) out = cupy.pad(data, pad_width, mode='constant', constant_values=np.nan) else: # Promote integer dtypes so NaN fill works if np.issubdtype(data.dtype, np.integer): data = data.astype(np.float64) out = np.pad(data, pad_width, mode='constant', constant_values=np.nan) return out
def fused_overlap(agg, *stages, boundary='nan'):
    """Run multiple overlap operations in a single map_overlap call.

    Each stage is a ``(func, depth)`` pair. ``func`` receives a padded
    chunk and returns the unpadded interior result. On the dask path the
    stages are chained inside one ``map_overlap`` call whose depth is the
    per-axis sum of all stage depths, giving one blockwise graph layer
    instead of one per stage. On the non-dask path each stage pads its
    input with NaN by its own depth and then calls ``func``.

    Parameters
    ----------
    agg : xr.DataArray
        Input raster.
    *stages : tuple of (callable, depth)
        Each ``func`` takes array ``(H+2*d, W+2*d)`` -> ``(H, W)``.
        ``depth`` is int, tuple, or dict.
    boundary : str
        Must be ``'nan'``.

    Returns
    -------
    xr.DataArray
        Copy of *agg* carrying the result data.

    Raises
    ------
    TypeError
        If *agg* is not an ``xr.DataArray``.
    ValueError
        If no stage is given, *boundary* is not ``'nan'``, or (dask path)
        a chunk is smaller than the total depth along its axis.
    """
    if not isinstance(agg, xr.DataArray):
        raise TypeError(
            f"fused_overlap(): expected xr.DataArray, "
            f"got {type(agg).__name__}"
        )
    if not stages:
        raise ValueError("fused_overlap(): need at least one stage")
    if boundary != 'nan':
        raise ValueError(
            f"fused_overlap(): boundary must be 'nan', got {boundary!r}"
        )

    ndim = agg.ndim
    axes = range(ndim)

    # One normalized {axis: depth} dict per stage, plus their per-axis sum.
    per_stage = [_normalize_depth(d, ndim) for _, d in stages]
    total_depth = {ax: sum(sd[ax] for sd in per_stage) for ax in axes}
    funcs = [f for f, _ in stages]

    # --- non-dask path: pad and apply stage by stage ---
    if not (has_dask_array() and isinstance(agg.data, da.Array)):
        current = agg.data
        for func, depths in zip(funcs, per_stage):
            widths = tuple(depths[ax] for ax in axes)
            current = func(_pad_nan(current, widths))
        return agg.copy(data=current)

    # --- dask path ---
    # Every chunk must be at least as large as the combined overlap.
    for ax, needed in total_depth.items():
        for chunk_size in agg.chunks[ax]:
            if chunk_size < needed:
                raise ValueError(
                    f"Chunk size {chunk_size} on axis {ax} is smaller than "
                    f"total depth {needed}. Rechunk first."
                )

    def _fused_wrapper(block):
        # Each func strips its own depth, so the block shrinks stage by
        # stage down to the unpadded interior.
        out_block = block
        for func in funcs:
            out_block = func(out_block)
        return out_block

    # trim=False: the stage funcs already return the interior.
    fused = agg.data.map_overlap(
        _fused_wrapper,
        depth=total_depth,
        boundary=np.nan,
        trim=False,
        meta=np.array(()),
    )
    return agg.copy(data=fused)
def multi_overlap(agg, func, n_outputs, depth, boundary='nan', dtype=None):
    """Run a multi-output kernel via a single overlap + map_blocks call.

    ``func`` receives a padded 2-D chunk and returns ``(n_outputs, H, W)``
    -- the unpadded interior for each output band. The result is a 3-D
    DataArray with a leading ``band`` dimension.

    Parameters
    ----------
    agg : xr.DataArray
        2-D input raster.
    func : callable
        ``(H+2*dy, W+2*dx) -> (n_outputs, H, W)``
    n_outputs : int
        Number of output bands (>= 1).
    depth : int or tuple of int
        Per-axis overlap (>= 1 on each axis).
    boundary : str
        Boundary mode: 'nan', 'nearest', 'reflect', or 'wrap'.
    dtype : numpy dtype, optional
        Output dtype. Defaults to input dtype when ``None``.

    Returns
    -------
    xr.DataArray
        Shape ``(n_outputs, H, W)`` with ``band`` leading dimension;
        coordinates and attributes are taken from *agg*.

    Raises
    ------
    TypeError
        If *agg* is not an ``xr.DataArray``.
    ValueError
        If *agg* is not 2-D, ``n_outputs < 1``, *boundary* is not a
        recognised mode, a depth is < 1, or (dask path) a chunk is
        smaller than the depth along its axis.
    """
    if not isinstance(agg, xr.DataArray):
        raise TypeError(
            f"multi_overlap(): expected xr.DataArray, "
            f"got {type(agg).__name__}"
        )
    if agg.ndim != 2:
        raise ValueError(
            f"multi_overlap(): input must be 2-D, got {agg.ndim}-D"
        )
    if n_outputs < 1:
        raise ValueError(
            f"multi_overlap(): n_outputs must be >= 1, got {n_outputs}"
        )
    _validate_boundary(boundary)

    depth_dict = _normalize_depth(depth, agg.ndim)
    for ax, d in depth_dict.items():
        if d < 1:
            raise ValueError(
                f"multi_overlap(): depth must be >= 1, got {d} on axis {ax}"
            )

    # Compare against None explicitly: ``dtype or agg.dtype`` would throw
    # away an explicit ``np.dtype`` instance wherever a plain dtype is
    # falsy (its ``len()`` is the number of fields, i.e. 0).
    if dtype is None:
        dtype = agg.dtype

    out_dims = ['band'] + list(agg.dims)

    # --- non-dask path ---
    if not has_dask_array() or not isinstance(agg.data, da.Array):
        depth_tuple = tuple(depth_dict[ax] for ax in range(agg.ndim))
        if boundary == 'nan':
            padded = _pad_nan(agg.data, depth_tuple)
        else:
            padded = _pad_array(agg.data, depth_tuple, boundary)
        result_data = func(padded).astype(dtype)
        return xr.DataArray(
            result_data,
            dims=out_dims,
            coords=agg.coords,
            attrs=agg.attrs,
        )

    # --- dask path ---
    import dask.array.overlap as _dask_overlap

    boundary_val = _boundary_to_dask(boundary, is_cupy=is_cupy_backed(agg))

    # Validate chunk sizes
    for ax, d in depth_dict.items():
        for cs in agg.chunks[ax]:
            if cs < d:
                raise ValueError(
                    f"Chunk size {cs} on axis {ax} is smaller than "
                    f"depth {d}. Rechunk first."
                )

    # Step 1: pad with overlap
    padded = _dask_overlap.overlap(
        agg.data, depth=depth_dict, boundary=boundary_val
    )

    # Step 2: map_blocks -- func returns (n_outputs, H, W) per block, so
    # a new leading axis with a single chunk of size n_outputs is declared
    # and the spatial chunking of the (unpadded) input is reused.
    out = da.map_blocks(
        func,
        padded,
        dtype=dtype,
        new_axis=0,
        chunks=((n_outputs,),) + agg.data.chunks,
    )
    return xr.DataArray(
        out,
        dims=out_dims,
        coords=agg.coords,
        attrs=agg.attrs,
    )