# Source code for xrspatial.hydro.sink_d8

"""Sink identification: find and label depression cells in a D8 flow direction grid.

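Identifies cells with direction code 0 (pit/flat with no downhill neighbor)
and labels 8-connected groups of them: BFS flood fill on CPU, iterative
min-label propagation on GPU, and per-tile BFS plus a cross-tile
union-find merge for dask-backed inputs.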
"""

from __future__ import annotations

import numpy as np
import xarray as xr
from numba import cuda

try:
    import cupy
except ImportError:
    class cupy:  # type: ignore[no-redef]
        ndarray = False

try:
    import dask.array as da
except ImportError:
    da = None

from xrspatial.dataset_support import supports_dataset
from xrspatial.utils import (_validate_raster, cuda_args, has_cuda_and_cupy, is_cupy_array,
                             is_dask_cupy, ngjit)

# =====================================================================
# Memory guards
# =====================================================================
#
# CPU peak working set per pixel for ``_sink_cpu``:
#   labels  : float64 -> 8
#   queue_r : int64   -> 8
#   queue_c : int64   -> 8
# Total ~24 bytes/pixel.  The caller-provided ``flow_dir`` array already
# lives in RAM before the kernel runs and is not double-counted here.
_BYTES_PER_PIXEL = 24

# GPU peak working set per pixel for ``_sink_cupy``:
#   labels : float64 -> 8
# Total ~8 bytes/pixel.  ``flow_dir_data`` already lives on the device
# before the kernel runs and is not double-counted here.
_GPU_BYTES_PER_PIXEL = 8


def _available_memory_bytes():
    """Best-effort estimate of available host memory in bytes."""
    try:
        with open('/proc/meminfo', 'r') as f:
            for line in f:
                if line.startswith('MemAvailable:'):
                    return int(line.split()[1]) * 1024  # kB -> bytes
    except (OSError, ValueError, IndexError):
        pass
    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass
    return 2 * 1024 ** 3


def _available_gpu_memory_bytes():
    """Best-effort estimate of free GPU memory in bytes.

    Returns 0 if CuPy / CUDA is unavailable or the query fails -- callers
    use that as a sentinel meaning "no GPU info, skip the guard".
    """
    try:
        import cupy as _cp
        free, _total = _cp.cuda.runtime.memGetInfo()
        return int(free)
    except Exception:
        return 0


def _check_memory(height, width):
    """Refuse numpy runs whose working set exceeds half of available RAM.

    The estimate is ``height * width * _BYTES_PER_PIXEL``, compared against
    ``_available_memory_bytes()``.

    Raises
    ------
    MemoryError
        If the estimate is larger than 50% of the available host memory.
    """
    needed = int(height) * int(width) * _BYTES_PER_PIXEL
    avail = _available_memory_bytes()
    if needed <= 0.5 * avail:
        return
    raise MemoryError(
        f"sink_d8 on a {height}x{width} grid requires "
        f"~{needed / 1e9:.1f} GB of working memory but only "
        f"~{avail / 1e9:.1f} GB is available.  Use a "
        f"dask-backed DataArray for out-of-core processing."
    )


def _check_gpu_memory(height, width):
    """Refuse CuPy runs whose working set exceeds half of free GPU memory.

    When ``_available_gpu_memory_bytes`` reports 0 (free memory could not
    be determined, e.g. no CUDA on this host) the guard is skipped; the
    kernel launch will fail on its own in that case.

    Raises
    ------
    MemoryError
        If ``height * width * _GPU_BYTES_PER_PIXEL`` is larger than 50% of
        the free device memory.
    """
    free = _available_gpu_memory_bytes()
    if free > 0:
        needed = int(height) * int(width) * _GPU_BYTES_PER_PIXEL
        if needed > 0.5 * free:
            raise MemoryError(
                f"sink_d8 on a {height}x{width} grid requires "
                f"~{needed / 1e9:.1f} GB of GPU working memory but only "
                f"~{free / 1e9:.1f} GB is free on the active device.  "
                f"Use a dask+cupy DataArray for out-of-core processing."
            )


# =====================================================================
# CPU kernel
# =====================================================================

@ngjit
def _sink_cpu(flow_dir, h, w, row_off, col_off, total_w):
    """Label 8-connected groups of sink cells (direction code 0) via BFS.

    Each group gets the ID ``(row_off + r) * total_w + (col_off + c) + 1``,
    where ``(r, c)`` is the first member met in a row-major scan.
    ``row_off``, ``col_off`` and ``total_w`` let a tile of a larger raster
    produce IDs that are unique across the whole raster.  Non-sink cells,
    including NaN nodata, are NaN in the returned float64 array.
    """
    labels = np.full((h, w), np.nan)

    # Work list for the flood fill; each cell is enqueued at most once,
    # so h * w slots always suffice.
    queue_r = np.empty(h * w, dtype=np.int64)
    queue_c = np.empty(h * w, dtype=np.int64)

    for r0 in range(h):
        for c0 in range(w):
            # Seed only on sink cells not yet labelled (NaN == 0.0 is False,
            # so nodata never seeds).
            if not (flow_dir[r0, c0] == 0.0 and np.isnan(labels[r0, c0])):
                continue

            label = float((row_off + r0) * total_w + (col_off + c0) + 1)
            labels[r0, c0] = label
            queue_r[0] = r0
            queue_c[0] = c0
            head = 0
            tail = 1

            while head < tail:
                r = queue_r[head]
                c = queue_c[head]
                head += 1
                # Offsets visited in the order NW, N, NE, W, E, SW, S, SE.
                for dr in range(-1, 2):
                    for dc in range(-1, 2):
                        if dr == 0 and dc == 0:
                            continue
                        nr = r + dr
                        nc = c + dc
                        if 0 <= nr < h and 0 <= nc < w:
                            if (flow_dir[nr, nc] == 0.0
                                    and np.isnan(labels[nr, nc])):
                                labels[nr, nc] = label
                                queue_r[tail] = nr
                                queue_c[tail] = nc
                                tail += 1

    return labels


# =====================================================================
# GPU kernels
# =====================================================================

@cuda.jit
def _sink_init_gpu(flow_dir, labels, H, W):
    """Seed labels: sink cells get their 1-based row-major index, others 0.

    One thread per cell.  NaN (nodata) compares unequal to 0.0, so it falls
    into the "not a sink" branch and is written as 0 like any other code.
    """
    i, j = cuda.grid(2)
    if i < H and j < W:
        code = flow_dir[i, j]
        if code == 0.0:
            labels[i, j] = float(i * W + j + 1)
        else:
            labels[i, j] = 0.0


@cuda.jit
def _sink_propagate_gpu(labels, changed, H, W):
    """Min-label propagation: each sink cell takes minimum neighbor label.

    One thread per cell (``cuda.grid(2)``).  Sink cells hold positive
    labels and every other cell holds 0 (as written by ``_sink_init_gpu``),
    so only positive neighbours are considered.  Whenever a cell lowers its
    label, ``changed[0]`` is incremented so the host driver launches another
    sweep.

    ``labels`` is written in place while neighbouring threads read it, so a
    thread may observe either a neighbour's old or new value.  Labels only
    ever decrease toward the component minimum, so either value is safe; the
    race only affects how many sweeps convergence takes.
    """
    i, j = cuda.grid(2)
    # Out-of-range threads (the launch grid may exceed the raster) exit.
    if i >= H or j >= W:
        return

    my_label = labels[i, j]
    # Non-sink / nodata cells (label 0) never participate.
    if my_label <= 0.0:
        return

    min_label = my_label
    # Visit the 8 neighbours; the if/elif chain maps k -> (dy, dx).
    for k in range(8):
        if k == 0:
            dy, dx = -1, -1
        elif k == 1:
            dy, dx = -1, 0
        elif k == 2:
            dy, dx = -1, 1
        elif k == 3:
            dy, dx = 0, -1
        elif k == 4:
            dy, dx = 0, 1
        elif k == 5:
            dy, dx = 1, -1
        elif k == 6:
            dy, dx = 1, 0
        else:
            dy, dx = 1, 1

        ni = i + dy
        nj = j + dx
        if ni < 0 or ni >= H or nj < 0 or nj >= W:
            continue
        nb = labels[ni, nj]
        # A positive neighbour is an adjacent sink, i.e. same component.
        if nb > 0.0 and nb < min_label:
            min_label = nb

    if min_label < my_label:
        labels[i, j] = min_label
        # Any non-zero count tells the host that another sweep is needed.
        cuda.atomic.add(changed, 0, 1)


def _sink_cupy(flow_dir_data):
    """GPU driver for sink identification.

    Every sink cell (code 0) is first seeded with its own 1-based row-major
    index by ``_sink_init_gpu``.  ``_sink_propagate_gpu`` is then launched
    repeatedly until no label changes, so each 8-connected group converges
    to the smallest index it contains.  These are the same IDs the CPU BFS
    assigns.  Non-sink and NaN cells come back as NaN.
    """
    import cupy as cp

    H, W = flow_dir_data.shape
    flow_dir_f64 = flow_dir_data.astype(cp.float64)
    labels = cp.zeros((H, W), dtype=cp.float64)
    changed = cp.zeros(1, dtype=cp.int32)

    griddim, blockdim = cuda_args((H, W))
    _sink_init_gpu[griddim, blockdim](flow_dir_f64, labels, H, W)

    # After k sweeps every sink cell holds a label no larger than the
    # minimum within k steps of it, so a group needs as many sweeps as its
    # longest internal (geodesic) path.  For a winding, serpentine group
    # that path can approach H * W cells, far beyond max(H, W).  Capping
    # at max(H, W) used to leave such groups split into several labels.
    # Instead, iterate until a sweep changes nothing; H * W + 1 sweeps is
    # a hard upper bound that never truncates a legitimate run.
    max_iter = H * W + 1
    for _ in range(max_iter):
        changed[0] = 0
        _sink_propagate_gpu[griddim, blockdim](labels, changed, H, W)
        if int(changed[0]) == 0:
            break

    return cp.where(labels > 0, labels, cp.nan)


# =====================================================================
# Backend wrappers
# =====================================================================

def _run_numpy(data):
    """Label sinks in an in-memory numpy flow-direction grid.

    ``_sink_cpu`` only reads ``flow_dir``, so ``np.ascontiguousarray`` is
    used instead of ``astype``.  It copies only when *data* is not already a
    C-contiguous float64 array, which avoids a redundant full-grid copy for
    the common float64 case.
    """
    h, w = data.shape
    flow_dir = np.ascontiguousarray(data, dtype=np.float64)
    return _sink_cpu(flow_dir, h, w, 0, 0, w)


# =====================================================================
# Cross-tile union-find for dask CCL
# =====================================================================
#
# Per-tile CCL produces globally unique IDs but does not merge
# components that span tile boundaries.  After the per-tile pass we
# compare every pair of 8-adjacent cells and record an equivalence
# whenever both are sinks but carry different labels (which can only
# happen across a tile boundary), then union and remap labels.

def _uf_find(parent, x):
    """Path-halving find on a dict-backed union-find."""
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def _uf_union(parent, a, b):
    """Merge the sets containing *a* and *b*.

    The numerically smaller root becomes the representative, so every set's
    root is its smallest member.  Merged labels therefore do not depend on
    the order in which unions are applied.
    """
    root_a = _uf_find(parent, a)
    root_b = _uf_find(parent, b)
    if root_a != root_b:
        keep, drop = (root_a, root_b) if root_a < root_b else (root_b, root_a)
        parent[drop] = keep


def _collect_boundary_equivalences(labels_np):
    """Return a list of (label_a, label_b) pairs from interior tile edges.

    *labels_np* is the materialized numpy result of the per-tile CCL pass.
    We scan every interior row/column boundary plus the two diagonal
    pairs (NE/SW and NW/SE) so 8-connectivity is preserved across tiles.

    Pairs where either side is NaN or 0 are skipped.  Pairs with the
    same label on both sides are skipped too.
    """
    pairs = []

    def _scan(a, b):
        # a and b are matched-shape slices; record (la, lb) where both
        # are sink labels (non-NaN, non-zero).
        if a.size == 0:
            return
        valid = ~(np.isnan(a) | np.isnan(b))
        if not valid.any():
            return
        am = a[valid]
        bm = b[valid]
        diff = am != bm
        if not diff.any():
            return
        la = am[diff].astype(np.int64)
        lb = bm[diff].astype(np.int64)
        for i in range(la.size):
            pairs.append((int(la[i]), int(lb[i])))

    # Vertical neighbors (up-down): every row boundary
    _scan(labels_np[:-1, :], labels_np[1:, :])
    # Horizontal neighbors (left-right)
    _scan(labels_np[:, :-1], labels_np[:, 1:])
    # Diagonal NW-SE
    _scan(labels_np[:-1, :-1], labels_np[1:, 1:])
    # Diagonal NE-SW
    _scan(labels_np[:-1, 1:], labels_np[1:, :-1])
    return pairs


def _build_label_remap(labels_np):
    """Map each non-root sink label to the root of its cross-tile group.

    Labels that are already their own root are left out, so an empty dict
    means no sink group spans a tile boundary.
    """
    parent = {}
    for a, b in _collect_boundary_equivalences(labels_np):
        parent.setdefault(a, a)
        parent.setdefault(b, b)
        _uf_union(parent, a, b)

    # _uf_find only rewrites values of existing keys, so iterating over a
    # snapshot of the keys while it compresses paths is safe.
    roots = {label: _uf_find(parent, label) for label in list(parent)}
    return {label: root for label, root in roots.items() if root != label}


def _apply_label_remap(block, remap_keys, remap_vals):
    """Replace each label in *block* with its root from the remap arrays."""
    if remap_keys.size == 0:
        return block
    out = block.copy()
    # np.searchsorted gives O(N log K) lookup which beats a Python dict
    # in the inner loop and is easy to vectorize.
    flat = out.ravel()
    valid = ~np.isnan(flat)
    if not valid.any():
        return out
    vals = flat[valid].astype(np.int64)
    idx = np.searchsorted(remap_keys, vals)
    in_range = idx < remap_keys.size
    hits = np.zeros_like(vals, dtype=bool)
    hits[in_range] = remap_keys[idx[in_range]] == vals[in_range]
    if hits.any():
        new_vals = vals.astype(np.float64)
        new_vals[hits] = remap_vals[idx[hits]].astype(np.float64)
        flat[valid] = new_vals
        out = flat.reshape(block.shape)
    return out


def _merge_cross_tile_labels(labels_da):
    """Unify sink labels that were split by dask chunk boundaries.

    The per-tile labels are computed into memory once, because connected
    component labelling is inherently global.  They are then scanned for
    cross-tile equivalences, and the resulting label -> root mapping is
    applied per chunk through ``map_blocks``.  The result is always a dask
    array with the chunking of *labels_da*.
    """
    labels_np = labels_da.compute()
    remap = _build_label_remap(labels_np)
    merged = da.from_array(labels_np, chunks=labels_da.chunks)
    if not remap:
        return merged

    sorted_keys = sorted(remap)
    keys = np.array(sorted_keys, dtype=np.int64)
    vals = np.array([remap[key] for key in sorted_keys], dtype=np.int64)

    def _relabel(block, _keys=keys, _vals=vals):
        return _apply_label_remap(block, _keys, _vals)

    return merged.map_blocks(
        _relabel,
        dtype=np.float64,
        meta=np.array((), dtype=np.float64),
    )


def _run_dask_numpy(data):
    """Dask+numpy path: per-chunk BFS labelling, then a cross-tile merge.

    Each chunk is labelled with ``_sink_cpu`` using its position in the
    full array (from ``block_info``), so IDs are unique across chunks.
    ``_merge_cross_tile_labels`` then joins groups that touch across chunk
    edges.
    """
    full_width = data.shape[1]

    def _label_tile(block, block_info=None):
        if block_info is None or 0 not in block_info:
            return np.full(block.shape, np.nan, dtype=np.float64)
        location = block_info[0]['array-location']
        row_start = location[0][0]
        col_start = location[1][0]
        rows, cols = block.shape
        tile = np.asarray(block, dtype=np.float64)
        return _sink_cpu(tile, rows, cols, row_start, col_start, full_width)

    per_tile = da.map_blocks(
        _label_tile, data,
        dtype=np.float64,
        meta=np.array((), dtype=np.float64),
    )
    return _merge_cross_tile_labels(per_tile)


def _run_dask_cupy(data):
    """Dask+CuPy path: copy chunks to host, run the CPU path, copy back."""
    import cupy as cp

    def _to_host(block):
        return block.get()

    host_data = data.map_blocks(
        _to_host, dtype=data.dtype,
        meta=np.array((), dtype=data.dtype),
    )
    host_labels = _run_dask_numpy(host_data)
    return host_labels.map_blocks(
        cp.asarray, dtype=host_labels.dtype,
        meta=cp.array((), dtype=host_labels.dtype),
    )


# =====================================================================
# Public API
# =====================================================================

@supports_dataset
def sink_d8(flow_dir: xr.DataArray,
            name: str = 'sink') -> xr.DataArray:
    """Identify and label depression cells in a D8 flow direction grid.

    Finds cells with direction code 0 (pit/flat with no downhill neighbor)
    and labels connected groups using 8-connected component labeling.

    Parameters
    ----------
    flow_dir : xarray.DataArray or xr.Dataset
        2D D8 flow direction grid (codes 0/1/2/4/8/16/32/64/128; NaN for
        nodata).
    name : str, default='sink'
        Name of output DataArray.

    Returns
    -------
    xarray.DataArray or xr.Dataset
        2D float64 array where each sink cell is labeled with a unique
        group ID.  The ID is one plus the row-major linear index of the
        group's first cell in row-major order.  Non-sink cells and NaN
        cells are NaN.

    Raises
    ------
    MemoryError
        For in-memory numpy or CuPy input whose estimated working set
        exceeds half of the available host or free GPU memory.
    TypeError
        If the underlying array type is not supported.
    """
    _validate_raster(flow_dir, func_name='sink', name='flow_dir')

    data = flow_dir.data

    if isinstance(data, np.ndarray):
        _check_memory(*data.shape)
        out = _run_numpy(data)
    elif has_cuda_and_cupy() and is_cupy_array(data):
        _check_gpu_memory(*data.shape)
        out = _sink_cupy(data)
    elif has_cuda_and_cupy() and is_dask_cupy(flow_dir):
        out = _run_dask_cupy(data)
    elif da is not None and isinstance(data, da.Array):
        out = _run_dask_numpy(data)
    else:
        raise TypeError(f"Unsupported array type: {type(data)}")

    return xr.DataArray(out, name=name, coords=flow_dir.coords,
                        dims=flow_dir.dims, attrs=flow_dir.attrs)