# Source code for xrspatial.hydro.flow_accumulation_mfd

"""Flow accumulation for Multiple Flow Direction (MFD) grids.

Takes the (8, H, W) fractional flow direction output from
``flow_direction_mfd`` and accumulates upstream area through all
downslope paths simultaneously.

Algorithm
---------
CPU : Kahn's BFS topological sort -- O(N).
GPU : iterative frontier peeling with pull-based kernels.
Dask: iterative tile sweep with boundary propagation (one tile in
      RAM at a time), following the ``flow_accumulation.py`` pattern.
"""

from __future__ import annotations

import numpy as np
import xarray as xr
from numba import cuda

try:
    import cupy
except ImportError:
    class cupy:  # type: ignore[no-redef]
        ndarray = False

try:
    import dask.array as da
except ImportError:
    da = None

from xrspatial.utils import (
    _validate_mfd_fractions,
    _validate_raster,
    cuda_args,
    has_cuda_and_cupy,
    is_cupy_array,
    is_dask_cupy,
    ngjit,
)
from xrspatial.hydro._boundary_store import BoundaryStore
from xrspatial.dataset_support import supports_dataset


# =====================================================================
# Memory guards
# =====================================================================
#
# CPU peak working set per pixel for ``_flow_accum_mfd_cpu``:
#   accum    : float64 -> 8
#   in_degree: int32   -> 4
#   valid    : int8    -> 1
#   queue_r  : int64   -> 8
#   queue_c  : int64   -> 8
# Total ~29 bytes/pixel.  The caller-provided ``flow_dir_mfd`` array
# already lives in RAM before the kernel runs and is not double-counted
# here -- this matches the convention in ``flow_accumulation_d8``.
# NOTE(review): the public entry point converts the input with
# ``astype(np.float64)`` before calling the kernel; whenever that
# conversion allocates a new array it adds up to 8 bands x 8 bytes =
# 64 bytes/pixel that this estimate does not include.
_BYTES_PER_PIXEL = 29

# GPU peak working set per pixel for ``_flow_accum_mfd_cupy``:
#   accum     : float64 -> 8
#   in_degree : int32   -> 4
#   state     : int32   -> 4
# Total ~16 bytes/pixel.  ``fractions_data`` already lives on the device
# before the kernel runs and is not double-counted here.
# NOTE(review): ``_flow_accum_mfd_cupy`` converts the fractions with
# ``astype(cp.float64)``; when that allocates a new array it can add up
# to 64 bytes/pixel that this estimate does not include.
_GPU_BYTES_PER_PIXEL = 16


def _available_memory_bytes():
    """Best-effort estimate of available host memory in bytes."""
    try:
        with open('/proc/meminfo', 'r') as f:
            for line in f:
                if line.startswith('MemAvailable:'):
                    return int(line.split()[1]) * 1024  # kB -> bytes
    except (OSError, ValueError, IndexError):
        pass
    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass
    return 2 * 1024 ** 3


def _available_gpu_memory_bytes():
    """Best-effort estimate of free GPU memory in bytes.

    Returns 0 if CuPy / CUDA is unavailable or the query fails -- callers
    use that as a sentinel meaning "no GPU info, skip the guard".
    """
    try:
        import cupy as _cp
        free, _total = _cp.cuda.runtime.memGetInfo()
        return int(free)
    except Exception:
        return 0


def _check_memory(height, width):
    """Guard the CPU kernel against grids that will not fit in host RAM.

    Raises
    ------
    MemoryError
        When the estimated working set (``_BYTES_PER_PIXEL`` bytes per
        cell) exceeds half of the currently available host memory.
    """
    cells = int(height) * int(width)
    needed = cells * _BYTES_PER_PIXEL
    free = _available_memory_bytes()
    if needed <= 0.5 * free:
        return
    raise MemoryError(
        f"flow_accumulation_mfd on a {height}x{width} grid requires "
        f"~{needed / 1e9:.1f} GB of working memory but only "
        f"~{free / 1e9:.1f} GB is available.  Use a "
        f"dask-backed DataArray for out-of-core processing."
    )


def _check_gpu_memory(height, width):
    """Guard the CuPy kernel against grids that will not fit on the device.

    The check is skipped when the free device memory cannot be
    determined (``_available_gpu_memory_bytes`` returns 0), e.g. on hosts
    without CUDA where the kernel fails at the cupy.asarray boundary
    anyway.

    Raises
    ------
    MemoryError
        When the estimated working set (``_GPU_BYTES_PER_PIXEL`` bytes per
        cell) exceeds half of the free GPU memory.
    """
    free = _available_gpu_memory_bytes()
    # A non-positive value means "unknown": nothing to compare against.
    if free > 0:
        needed = int(height) * int(width) * _GPU_BYTES_PER_PIXEL
        if needed > 0.5 * free:
            raise MemoryError(
                f"flow_accumulation_mfd on a {height}x{width} grid requires "
                f"~{needed / 1e9:.1f} GB of GPU working memory but only "
                f"~{free / 1e9:.1f} GB is free on the active device.  "
                f"Use a dask+cupy DataArray for out-of-core processing."
            )


# Neighbor offsets: E, SE, S, SW, W, NW, N, NE
# Band k of an MFD fraction array sends flow to (row + _DY[k], col + _DX[k]);
# a positive row offset moves to the next row down (south).  The kernels
# and helpers below define their own copies of these tables.
_DY = np.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=np.int64)
_DX = np.array([1, 1, 0, -1, -1, -1, 0, 1], dtype=np.int64)

# Opposite neighbor index (who points back at me?)
# E(0)->W(4), SE(1)->NW(5), S(2)->N(6), SW(3)->NE(7), ...
# Equivalently, _OPPOSITE[k] == (k + 4) % 8.
_OPPOSITE = np.array([4, 5, 6, 7, 0, 1, 2, 3], dtype=np.int64)


# =====================================================================
# CPU kernel
# =====================================================================

@ngjit
def _flow_accum_mfd_cpu(fractions, height, width):
    """Kahn's BFS topological sort for MFD flow accumulation.

    Every valid cell starts with an accumulation of 1.0 (its own area).
    A cell is dequeued only after every valid upstream neighbour that
    sends it a positive fraction has been processed.  Each dequeued cell
    then adds ``accum * fraction`` to every valid in-bounds downstream
    neighbour.

    A cell is treated as nodata when its band-0 fraction is NaN.  Such
    cells get a NaN accumulation and neither send nor receive flow.
    Fractions that are <= 0 or NaN are ignored (all tests use
    ``> 0.0``).

    Parameters
    ----------
    fractions : (8, H, W) float64 array of flow fractions
        Band ``k`` is the fraction sent towards neighbour ``k`` in the
        order E, SE, S, SW, W, NW, N, NE (see ``dy`` / ``dx``).
    height, width : int

    Returns
    -------
    accum : (H, W) float64 array

    Raises
    ------
    ValueError
        If the positive-fraction graph contains a cycle, so that some
        valid cells are never dequeued.
    """
    # Row / column offsets per band, in band order E, SE, S, SW, W, NW,
    # N, NE (a positive dy moves to the next row down).
    dy = np.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=np.int64)
    dx = np.array([1, 1, 0, -1, -1, -1, 0, 1], dtype=np.int64)

    accum = np.empty((height, width), dtype=np.float64)
    in_degree = np.zeros((height, width), dtype=np.int32)
    valid = np.zeros((height, width), dtype=np.int8)

    # Pass 1: initialise
    # Nodata is decided from band 0 only; each valid cell contributes 1.0.
    n_valid = 0
    for r in range(height):
        for c in range(width):
            v = fractions[0, r, c]
            if v != v:  # NaN
                accum[r, c] = np.nan
            else:
                valid[r, c] = 1
                accum[r, c] = 1.0
                n_valid += 1

    # Pass 2: compute in-degrees
    # in_degree[cell] = number of valid upstream cells sending it a
    # positive fraction.  Edges leaving the grid or entering nodata cells
    # are dropped, so they never block anything.
    for r in range(height):
        for c in range(width):
            if valid[r, c] == 0:
                continue
            for k in range(8):
                if fractions[k, r, c] > 0.0:
                    nr = r + dy[k]
                    nc = c + dx[k]
                    if 0 <= nr < height and 0 <= nc < width:
                        if valid[nr, nc] == 1:
                            in_degree[nr, nc] += 1

    # BFS queue
    # A cell is enqueued at most once (when its in-degree is, or first
    # drops to, zero), so height * width slots are always enough.
    queue_r = np.empty(height * width, dtype=np.int64)
    queue_c = np.empty(height * width, dtype=np.int64)
    head = np.int64(0)
    tail = np.int64(0)

    # Seed the queue with sources: valid cells with no upstream inflow.
    for r in range(height):
        for c in range(width):
            if valid[r, c] == 1 and in_degree[r, c] == 0:
                queue_r[tail] = r
                queue_c[tail] = c
                tail += 1

    while head < tail:
        r = queue_r[head]
        c = queue_c[head]
        head += 1

        # accum[r, c] is final here: every upstream contributor was
        # dequeued (and added its share) before this cell's in-degree
        # reached zero.
        for k in range(8):
            frac = fractions[k, r, c]
            if frac > 0.0:
                nr = r + dy[k]
                nc = c + dx[k]
                if 0 <= nr < height and 0 <= nc < width:
                    if valid[nr, nc] == 1:
                        accum[nr, nc] += accum[r, c] * frac
                        in_degree[nr, nc] -= 1
                        if in_degree[nr, nc] == 0:
                            queue_r[tail] = nr
                            queue_c[tail] = nc
                            tail += 1

    # If a cycle remains, some valid cells never reached in_degree 0 and
    # were never dequeued.  head counts the cells processed by the BFS.
    if head < n_valid:
        raise ValueError(
            "flow_accumulation_mfd: the MFD fraction grid contains a cycle "
            "(some cells never reach zero in-degree).  The input must be a "
            "directed acyclic graph, as produced by flow_direction_mfd."
        )

    return accum


# =====================================================================
# GPU kernels
# =====================================================================

@cuda.jit
def _init_accum_indegree_mfd(fractions, accum, in_degree, state, H, W):
    """Initialise accum, in_degree and state for MFD on GPU.

    One thread per cell (i, j):

    * nodata cells (band-0 fraction is NaN) get state 0 and accum 0.0.
      The driver replaces state-0 cells with NaN at the end.
    * valid cells get state 1 ("waiting") and accum 1.0, and atomically
      increment the in-degree of every in-bounds neighbour they send a
      positive fraction to.

    NaN fractions in bands 1-7 are treated like zero.  This matches the
    ``frac > 0.0`` test used by ``_pull_from_frontier_mfd`` and by the CPU
    kernel.
    """
    i, j = cuda.grid(2)
    if i >= H or j >= W:
        return

    v = fractions[0, i, j]
    if v != v:  # NaN
        state[i, j] = 0
        accum[i, j] = 0.0
        return

    state[i, j] = 1
    accum[i, j] = 1.0

    # Neighbor offsets: E, SE, S, SW, W, NW, N, NE
    for k in range(8):
        frac = fractions[k, i, j]
        # ``not (frac > 0.0)`` rather than ``frac <= 0.0``: the latter is
        # False for NaN.  It would count an edge that the pull kernel
        # (``frac > 0.0``) never decrements, stranding the receiver in
        # the waiting state forever.
        if not (frac > 0.0):
            continue

        if k == 0:
            dy, dx = 0, 1
        elif k == 1:
            dy, dx = 1, 1
        elif k == 2:
            dy, dx = 1, 0
        elif k == 3:
            dy, dx = 1, -1
        elif k == 4:
            dy, dx = 0, -1
        elif k == 5:
            dy, dx = -1, -1
        elif k == 6:
            dy, dx = -1, 0
        else:
            dy, dx = -1, 1

        ni = i + dy
        nj = j + dx
        if 0 <= ni < H and 0 <= nj < W:
            # Several upstream threads may target the same cell.
            cuda.atomic.add(in_degree, (ni, nj), 1)


@cuda.jit
def _find_ready_and_finalize_mfd(in_degree, state, changed, H, W):
    """Finalize previous frontier (2->3), mark new frontier (1->2).

    Cell states shared by the GPU kernels:

    * 0 = nodata
    * 1 = waiting for upstream inflow
    * 2 = current frontier: in-degree reached zero, so the accumulation
      is complete and downstream neighbours pull from it this round
    * 3 = finished

    ``changed[0]`` is atomically incremented once per newly marked
    frontier cell.  The driver stops iterating when a launch marks none.
    Each thread only reads and writes its own cell's state.
    """
    i, j = cuda.grid(2)
    if i >= H or j >= W:
        return

    # Cells that were pulled from during the previous round are done.
    if state[i, j] == 2:
        state[i, j] = 3

    # A waiting cell whose upstream contributors have all been pulled
    # joins the new frontier.
    if state[i, j] == 1 and in_degree[i, j] == 0:
        state[i, j] = 2
        cuda.atomic.add(changed, 0, 1)


@cuda.jit
def _pull_from_frontier_mfd(fractions, accum, in_degree, state, H, W):
    """Active MFD cells pull accumulation from frontier neighbours.

    Each waiting cell (state 1) looks at its 8 neighbours.  For every
    neighbour on the current frontier (state 2) that sends it a positive
    fraction, the cell adds ``accum[neighbour] * fraction`` to its own
    accumulation and decrements its own in-degree.

    Threads only write ``accum`` / ``in_degree`` at their own (i, j).
    They only read from state-2 cells, which no thread writes during this
    launch, so no atomics are needed.
    """
    i, j = cuda.grid(2)
    if i >= H or j >= W:
        return

    if state[i, j] != 1:
        return

    # Opposite direction index: if neighbor k sent flow in direction k,
    # I am the opposite direction from them.
    # E(0)->W(4), SE(1)->NW(5), S(2)->N(6), SW(3)->NE(7), etc.
    for nbr in range(8):
        if nbr == 0:
            dy, dx = 0, 1
        elif nbr == 1:
            dy, dx = 1, 1
        elif nbr == 2:
            dy, dx = 1, 0
        elif nbr == 3:
            dy, dx = 1, -1
        elif nbr == 4:
            dy, dx = 0, -1
        elif nbr == 5:
            dy, dx = -1, -1
        elif nbr == 6:
            dy, dx = -1, 0
        else:
            dy, dx = -1, 1

        ni = i + dy
        nj = j + dx
        if ni < 0 or ni >= H or nj < 0 or nj >= W:
            continue
        # Only frontier neighbours have a complete accumulation to share.
        if state[ni, nj] != 2:
            continue

        # Opposite of nbr: the direction from neighbor back to me
        if nbr == 0:
            opp = 4
        elif nbr == 1:
            opp = 5
        elif nbr == 2:
            opp = 6
        elif nbr == 3:
            opp = 7
        elif nbr == 4:
            opp = 0
        elif nbr == 5:
            opp = 1
        elif nbr == 6:
            opp = 2
        else:
            opp = 3

        # ``frac > 0.0`` is False for NaN, so NaN fractions send nothing.
        frac = fractions[opp, ni, nj]
        if frac > 0.0:
            accum[i, j] += accum[ni, nj] * frac
            in_degree[i, j] -= 1


def _flow_accum_mfd_cupy(fractions_data):
    """GPU driver: iterative frontier peeling for MFD.

    Each round does two things:

    (a) Promote every waiting cell whose in-degree has dropped to zero
        onto the frontier.
    (b) Let waiting cells pull ``accum * fraction`` from their frontier
        neighbours.

    Rounds repeat until one adds no new frontier cell.

    Parameters
    ----------
    fractions_data : cupy.ndarray, shape (8, H, W)
        MFD flow fractions; converted to float64 only if necessary.

    Returns
    -------
    cupy.ndarray, shape (H, W), float64
        Flow accumulation; NaN where band 0 of the input is NaN.

    Raises
    ------
    ValueError
        If the fraction grid contains a cycle (valid cells left waiting
        once no new frontier can form), mirroring the CPU kernel.
    """
    import cupy as cp

    _, H, W = fractions_data.shape
    # The kernels only read the fractions, so an input that is already
    # float64 is used as-is instead of being duplicated on the device.
    fractions_f64 = fractions_data.astype(cp.float64, copy=False)

    accum = cp.zeros((H, W), dtype=cp.float64)
    in_degree = cp.zeros((H, W), dtype=cp.int32)
    state = cp.zeros((H, W), dtype=cp.int32)
    changed = cp.zeros(1, dtype=cp.int32)

    griddim, blockdim = cuda_args((H, W))

    _init_accum_indegree_mfd[griddim, blockdim](
        fractions_f64, accum, in_degree, state, H, W)

    # Every productive round moves at least one cell onto the frontier,
    # so H * W rounds always suffice for an acyclic grid.
    max_iter = H * W
    for _ in range(max_iter):
        changed[0] = 0
        _find_ready_and_finalize_mfd[griddim, blockdim](
            in_degree, state, changed, H, W)

        if int(changed[0]) == 0:
            break

        _pull_from_frontier_mfd[griddim, blockdim](
            fractions_f64, accum, in_degree, state, H, W)

    # A valid cell still waiting (state 1) never had its in-degree reach
    # zero, which normally means the graph has a cycle.  The CPU and dask
    # kernels raise in that case; do the same rather than silently
    # returning partial accumulations.
    if int(cp.count_nonzero(state == 1)) > 0:
        raise ValueError(
            "flow_accumulation_mfd: the MFD fraction grid contains a cycle "
            "(some cells never reach zero in-degree).  The input must be a "
            "directed acyclic graph, as produced by flow_direction_mfd."
        )

    accum = cp.where(state == 0, cp.nan, accum)
    return accum


# =====================================================================
# Dask tile kernel
# =====================================================================

@ngjit
def _flow_accum_mfd_tile_kernel(fractions, h, w,
                                 seed_top, seed_bottom,
                                 seed_left, seed_right,
                                 seed_tl, seed_tr, seed_bl, seed_br):
    """Seeded BFS MFD flow accumulation for a single tile.

    This is the same Kahn's BFS as ``_flow_accum_mfd_cpu``, restricted to
    one tile.  Valid edge cells first receive the inflow arriving from
    neighbouring tiles (the ``seed_*`` values).  Flow leaving the tile is
    dropped here; the dask driver carries it across tile edges through
    the stored boundary accumulations.

    Parameters
    ----------
    fractions : (8, h, w) float64 -- MFD flow fractions for this tile
    h, w : int -- tile height and width
    seed_top, seed_bottom : float64 arrays indexed by column (length >= w)
        Inflow added to the first / last row.
    seed_left, seed_right : float64 arrays indexed by row (length >= h)
        Inflow added to the first / last column.
    seed_tl, seed_tr, seed_bl, seed_br : float
        Diagonal inflow added to the corresponding corner cell.

    Returns
    -------
    accum : (h, w) float64 array, NaN where band 0 is NaN

    Raises
    ------
    ValueError
        If the fraction graph inside the tile contains a cycle.
    """
    dy = np.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=np.int64)
    dx = np.array([1, 1, 0, -1, -1, -1, 0, 1], dtype=np.int64)

    accum = np.empty((h, w), dtype=np.float64)
    in_degree = np.zeros((h, w), dtype=np.int32)
    valid = np.zeros((h, w), dtype=np.int8)

    # Initialise
    n_valid = 0
    for r in range(h):
        for c in range(w):
            v = fractions[0, r, c]
            if v == v:  # not NaN
                valid[r, c] = 1
                accum[r, c] = 1.0
                n_valid += 1
            else:
                accum[r, c] = np.nan

    # Add external seeds
    # Seeds landing on nodata cells are discarded (valid == 0).  Inflow
    # arriving before the BFS is simply part of the cell's starting value
    # and is propagated downstream like its own area.
    for c in range(w):
        if valid[0, c] == 1:
            accum[0, c] += seed_top[c]
        if valid[h - 1, c] == 1:
            accum[h - 1, c] += seed_bottom[c]
    for r in range(h):
        if valid[r, 0] == 1:
            accum[r, 0] += seed_left[r]
        if valid[r, w - 1] == 1:
            accum[r, w - 1] += seed_right[r]

    if valid[0, 0] == 1:
        accum[0, 0] += seed_tl
    if valid[0, w - 1] == 1:
        accum[0, w - 1] += seed_tr
    if valid[h - 1, 0] == 1:
        accum[h - 1, 0] += seed_bl
    if valid[h - 1, w - 1] == 1:
        accum[h - 1, w - 1] += seed_br

    # Compute in-degrees
    for r in range(h):
        for c in range(w):
            if valid[r, c] == 0:
                continue
            for k in range(8):
                if fractions[k, r, c] > 0.0:
                    nr = r + dy[k]
                    nc = c + dx[k]
                    if 0 <= nr < h and 0 <= nc < w and valid[nr, nc] == 1:
                        in_degree[nr, nc] += 1

    # BFS
    queue_r = np.empty(h * w, dtype=np.int64)
    queue_c = np.empty(h * w, dtype=np.int64)
    head = np.int64(0)
    tail = np.int64(0)

    for r in range(h):
        for c in range(w):
            if valid[r, c] == 1 and in_degree[r, c] == 0:
                queue_r[tail] = r
                queue_c[tail] = c
                tail += 1

    while head < tail:
        r = queue_r[head]
        c = queue_c[head]
        head += 1

        for k in range(8):
            frac = fractions[k, r, c]
            if frac > 0.0:
                nr = r + dy[k]
                nc = c + dx[k]
                if 0 <= nr < h and 0 <= nc < w and valid[nr, nc] == 1:
                    accum[nr, nc] += accum[r, c] * frac
                    in_degree[nr, nc] -= 1
                    if in_degree[nr, nc] == 0:
                        queue_r[tail] = nr
                        queue_c[tail] = nc
                        tail += 1

    # A cycle within this tile leaves some valid cells undequeued.
    if head < n_valid:
        raise ValueError(
            "flow_accumulation_mfd: the MFD fraction grid contains a cycle "
            "(some cells never reach zero in-degree).  The input must be a "
            "directed acyclic graph, as produced by flow_direction_mfd."
        )

    return accum


# =====================================================================
# Dask iterative tile sweep
# =====================================================================

def _preprocess_mfd_tiles(fractions_da, chunks_y, chunks_x):
    """Extract boundary fraction strips into a dict.

    For MFD we need the full 8-band fractions at each boundary cell,
    so we store them as (8, length) arrays.
    """
    n_tile_y = len(chunks_y)
    n_tile_x = len(chunks_x)

    # Store fraction strips keyed by (side, iy, ix)
    frac_bdry = {}

    for iy in range(n_tile_y):
        for ix in range(n_tile_x):
            # fractions_da is (8, H, W) dask array
            # Each tile's fractions: shape (8, tile_h, tile_w)
            chunk = fractions_da[:, sum(chunks_y[:iy]):sum(chunks_y[:iy+1]),
                                    sum(chunks_x[:ix]):sum(chunks_x[:ix+1])].compute()
            chunk = np.asarray(chunk, dtype=np.float64)

            # top row: (8, tile_w)
            frac_bdry[('top', iy, ix)] = chunk[:, 0, :].copy()
            # bottom row: (8, tile_w)
            frac_bdry[('bottom', iy, ix)] = chunk[:, -1, :].copy()
            # left col: (8, tile_h)
            frac_bdry[('left', iy, ix)] = chunk[:, :, 0].copy()
            # right col: (8, tile_h)
            frac_bdry[('right', iy, ix)] = chunk[:, :, -1].copy()

    return frac_bdry


def _compute_seeds_mfd(iy, ix, boundaries, frac_bdry,
                        chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Compute seed arrays for tile (iy, ix) from neighbour boundaries.

    For MFD, a neighbor cell flows into the current tile if its fraction
    for the direction pointing into our tile is > 0.
    """
    # Neighbor offsets: E(0), SE(1), S(2), SW(3), W(4), NW(5), N(6), NE(7)
    # Opposite:         W(4), NW(5), N(6), NE(7), E(0), SE(1), S(2), SE(3)

    tile_h = chunks_y[iy]
    tile_w = chunks_x[ix]

    seed_top = np.zeros(tile_w, dtype=np.float64)
    seed_bottom = np.zeros(tile_w, dtype=np.float64)
    seed_left = np.zeros(tile_h, dtype=np.float64)
    seed_right = np.zeros(tile_h, dtype=np.float64)
    seed_tl = 0.0
    seed_tr = 0.0
    seed_bl = 0.0
    seed_br = 0.0

    dy_arr = np.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=np.int64)
    dx_arr = np.array([1, 1, 0, -1, -1, -1, 0, 1], dtype=np.int64)

    # --- Top edge: bottom row of tile above ---
    if iy > 0:
        nb_frac = frac_bdry[('bottom', iy - 1, ix)]  # (8, tile_w)
        nb_accum = boundaries.get('bottom', iy - 1, ix)
        w = nb_frac.shape[1]
        for c in range(w):
            for k in range(8):
                if not (nb_frac[k, c] > 0.0):
                    continue
                # Direction k from neighbor: dy_arr[k], dx_arr[k]
                # Neighbor is in row above, so dy must be +1 to enter our tile
                ndy = dy_arr[k]
                ndx = dx_arr[k]
                if ndy == 1:  # flows south into our tile
                    tc = c + ndx
                    if 0 <= tc < tile_w:
                        seed_top[tc] += nb_accum[c] * nb_frac[k, c]

    # --- Bottom edge: top row of tile below ---
    if iy < n_tile_y - 1:
        nb_frac = frac_bdry[('top', iy + 1, ix)]  # (8, tile_w)
        nb_accum = boundaries.get('top', iy + 1, ix)
        w = nb_frac.shape[1]
        for c in range(w):
            for k in range(8):
                if not (nb_frac[k, c] > 0.0):
                    continue
                ndy = dy_arr[k]
                ndx = dx_arr[k]
                if ndy == -1:  # flows north into our tile
                    tc = c + ndx
                    if 0 <= tc < tile_w:
                        seed_bottom[tc] += nb_accum[c] * nb_frac[k, c]

    # --- Left edge: right column of tile to the left ---
    if ix > 0:
        nb_frac = frac_bdry[('right', iy, ix - 1)]  # (8, tile_h)
        nb_accum = boundaries.get('right', iy, ix - 1)
        h = nb_frac.shape[1]
        for r in range(h):
            for k in range(8):
                if not (nb_frac[k, r] > 0.0):
                    continue
                ndy = dy_arr[k]
                ndx = dx_arr[k]
                if ndx == 1:  # flows east into our tile
                    tr = r + ndy
                    if 0 <= tr < tile_h:
                        seed_left[tr] += nb_accum[r] * nb_frac[k, r]

    # --- Right edge: left column of tile to the right ---
    if ix < n_tile_x - 1:
        nb_frac = frac_bdry[('left', iy, ix + 1)]  # (8, tile_h)
        nb_accum = boundaries.get('left', iy, ix + 1)
        h = nb_frac.shape[1]
        for r in range(h):
            for k in range(8):
                if not (nb_frac[k, r] > 0.0):
                    continue
                ndy = dy_arr[k]
                ndx = dx_arr[k]
                if ndx == -1:  # flows west into our tile
                    tr = r + ndy
                    if 0 <= tr < tile_h:
                        seed_right[tr] += nb_accum[r] * nb_frac[k, r]

    # --- Diagonal corner seeds ---
    # TL: bottom-right cell of (iy-1, ix-1) flows SE (dy=1, dx=1 -> k=1)
    if iy > 0 and ix > 0:
        nb_frac = frac_bdry[('bottom', iy - 1, ix - 1)]  # (8, w)
        av = float(boundaries.get('bottom', iy - 1, ix - 1)[-1])
        frac_se = nb_frac[1, -1]  # SE direction
        if frac_se > 0.0:
            seed_tl += av * frac_se

    # TR: bottom-left cell of (iy-1, ix+1) flows SW (dy=1, dx=-1 -> k=3)
    if iy > 0 and ix < n_tile_x - 1:
        nb_frac = frac_bdry[('bottom', iy - 1, ix + 1)]  # (8, w)
        av = float(boundaries.get('bottom', iy - 1, ix + 1)[0])
        frac_sw = nb_frac[3, 0]  # SW direction
        if frac_sw > 0.0:
            seed_tr += av * frac_sw

    # BL: top-right cell of (iy+1, ix-1) flows NE (dy=-1, dx=1 -> k=7)
    if iy < n_tile_y - 1 and ix > 0:
        nb_frac = frac_bdry[('top', iy + 1, ix - 1)]  # (8, w)
        av = float(boundaries.get('top', iy + 1, ix - 1)[-1])
        frac_ne = nb_frac[7, -1]  # NE direction
        if frac_ne > 0.0:
            seed_bl += av * frac_ne

    # BR: top-left cell of (iy+1, ix+1) flows NW (dy=-1, dx=-1 -> k=5)
    if iy < n_tile_y - 1 and ix < n_tile_x - 1:
        nb_frac = frac_bdry[('top', iy + 1, ix + 1)]  # (8, w)
        av = float(boundaries.get('top', iy + 1, ix + 1)[0])
        frac_nw = nb_frac[5, 0]  # NW direction
        if frac_nw > 0.0:
            seed_br += av * frac_nw

    return (seed_top, seed_bottom, seed_left, seed_right,
            seed_tl, seed_tr, seed_bl, seed_br)


def _process_tile_mfd(iy, ix, fractions_da, boundaries, frac_bdry,
                       chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Re-run the seeded MFD kernel on tile (iy, ix) and store its edges.

    The tile's four outer rows/columns of accumulation are written back
    to ``boundaries`` (NaN replaced by 0.0, since nodata sends no flow).

    Returns
    -------
    float
        The largest absolute change between the previously stored edge
        values and the new ones (0.0 means this tile has converged).
    """
    row0 = sum(chunks_y[:iy])
    col0 = sum(chunks_x[:ix])
    window = fractions_da[:, row0:row0 + chunks_y[iy],
                          col0:col0 + chunks_x[ix]]
    tile = np.asarray(window.compute(), dtype=np.float64)
    _, h, w = tile.shape

    seeds = _compute_seeds_mfd(
        iy, ix, boundaries, frac_bdry,
        chunks_y, chunks_x, n_tile_y, n_tile_x)
    accum = _flow_accum_mfd_tile_kernel(tile, h, w, *seeds)

    raw_edges = (
        ('top', accum[0, :]),
        ('bottom', accum[-1, :]),
        ('left', accum[:, 0]),
        ('right', accum[:, -1]),
    )
    edges = [(side, np.where(np.isnan(vals), 0.0, vals))
             for side, vals in raw_edges]

    # Compare against every stored edge before overwriting any of them.
    biggest = 0.0
    for side, fresh in edges:
        previous = boundaries.get(side, iy, ix)
        with np.errstate(invalid='ignore'):
            delta = np.abs(fresh - previous)
        step = float(np.max(np.where(np.isnan(delta), 0.0, delta)))
        biggest = max(biggest, step)

    for side, fresh in edges:
        boundaries.set(side, iy, ix, fresh)

    return biggest


def _flow_accum_mfd_dask_iterative(fractions_da, chunks_y, chunks_x):
    """Iterative boundary-propagation for MFD dask arrays.

    Each iteration does a forward (row-major) and then a backward sweep
    over all tiles.  Each visit re-runs the seeded tile kernel and stores
    the tile's edge accumulations.  Iteration stops when a full iteration
    changes no stored edge value.  The converged edges then seed a lazy
    per-tile reassembly.

    Parameters
    ----------
    fractions_da : dask array of shape (8, H, W)
    chunks_y, chunks_x : tuples of chunk sizes for the spatial dims

    Returns
    -------
    dask array of shape (H, W), float64

    Warns
    -----
    RuntimeWarning
        If the iteration cap is reached before the boundaries stop
        changing.  The returned values may then be incomplete, which was
        previously silent.
    """
    n_tile_y = len(chunks_y)
    n_tile_x = len(chunks_x)

    # The 8 direction bands must stay in a single chunk: every tile kernel
    # needs all 8 fractions, and the lazy assembly drops axis 0 per block.
    if fractions_da.chunks[0] != (fractions_da.shape[0],):
        fractions_da = fractions_da.rechunk({0: fractions_da.shape[0]})

    # Phase 0: extract boundary fraction strips
    frac_bdry = _preprocess_mfd_tiles(fractions_da, chunks_y, chunks_x)

    # Phase 1: initialise boundary accum to 0
    boundaries = BoundaryStore(chunks_y, chunks_x, fill_value=0.0)

    # Phase 2: iterative forward/backward sweeps.  The backward sweep
    # visits tiles in exactly the reverse of the forward order.
    forward = [(iy, ix) for iy in range(n_tile_y) for ix in range(n_tile_x)]
    sweep_order = forward + forward[::-1]

    max_iterations = max(n_tile_y, n_tile_x) + 10
    converged = False

    for _iteration in range(max_iterations):
        max_change = 0.0
        for iy, ix in sweep_order:
            c = _process_tile_mfd(iy, ix, fractions_da, boundaries,
                                  frac_bdry, chunks_y, chunks_x,
                                  n_tile_y, n_tile_x)
            if c > max_change:
                max_change = c

        # The tile kernel is deterministic, so an unchanged input
        # reproduces bit-identical edges and the change is exactly 0.0.
        if max_change == 0.0:
            converged = True
            break

    if not converged:
        import warnings
        warnings.warn(
            f"flow_accumulation_mfd: boundary propagation did not converge "
            f"after {max_iterations} iterations over a "
            f"{n_tile_y}x{n_tile_x} tile grid; results may be incomplete.  "
            f"Consider using larger chunks.",
            RuntimeWarning,
            stacklevel=2,
        )

    boundaries = boundaries.snapshot()

    return _assemble_result_mfd(fractions_da, boundaries, frac_bdry,
                                chunks_y, chunks_x, n_tile_y, n_tile_x)


def _assemble_result_mfd(fractions_da, boundaries, frac_bdry,
                          chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Lazily build the (H, W) result by replaying every tile.

    ``fractions_da`` is the (8, H, W) fraction array, one block per
    (chunks_y, chunks_x) tile with all 8 bands in a single chunk.  After
    the boundary sweep converges, each tile's final accumulation depends
    only on its own fractions plus seeds derived from the small boundary
    snapshot and fraction strips.  Those are captured in a closure, and
    ``map_blocks`` runs the tile kernel at compute time.  No full-size
    output is materialised during the API call.
    """
    def _start_to_index(sizes):
        # Map each tile's starting offset to its tile index.  On
        # duplicate offsets (zero-size chunks) the last tile wins.
        lookup = {}
        offset = 0
        for idx, size in enumerate(sizes):
            lookup[offset] = idx
            offset += size
        return lookup

    row_of = _start_to_index(chunks_y)
    col_of = _start_to_index(chunks_x)

    def _tile(chunk, block_info=None):
        # array-location is ((0, 8), (y0, y1), (x0, x1)) for this block.
        where = block_info[0]['array-location']
        iy = row_of[where[1][0]]
        ix = col_of[where[2][0]]

        block = np.asarray(chunk, dtype=np.float64)
        _, h, w = block.shape
        tile_seeds = _compute_seeds_mfd(
            iy, ix, boundaries, frac_bdry,
            chunks_y, chunks_x, n_tile_y, n_tile_x)
        return _flow_accum_mfd_tile_kernel(block, h, w, *tile_seeds)

    empty_meta = np.array((), dtype=np.float64)
    return da.map_blocks(
        _tile, fractions_da, drop_axis=0,
        dtype=np.float64, meta=empty_meta,
    )


def _flow_accum_mfd_dask_cupy(fractions_da, chunks_y, chunks_x):
    """Dask+CuPy MFD: convert to numpy, run iterative, convert back.

    Blocks are copied to host memory, processed by the NumPy dask path,
    and the resulting (H, W) blocks are moved back to the device lazily.
    """
    import cupy as cp

    def _to_host(block):
        return block.get()

    src_dtype = fractions_da.dtype
    host_fractions = fractions_da.map_blocks(
        _to_host, dtype=src_dtype,
        meta=np.array((), dtype=src_dtype),
    )
    host_result = _flow_accum_mfd_dask_iterative(
        host_fractions, chunks_y, chunks_x)

    out_dtype = host_result.dtype
    return host_result.map_blocks(
        cp.asarray, dtype=out_dtype,
        meta=cp.array((), dtype=out_dtype),
    )


# =====================================================================
# Public API
# =====================================================================

@supports_dataset
def flow_accumulation_mfd(flow_dir_mfd: xr.DataArray,
                          name: str = 'flow_accumulation_mfd') -> xr.DataArray:
    """Compute flow accumulation from an MFD flow direction grid.

    Takes the 3-D fractional output of ``flow_direction_mfd`` and
    accumulates upstream contributing area through all downslope paths
    simultaneously.  Each cell starts with a value of 1 (itself) and
    passes fractions of its accumulated value to each downstream
    neighbor.

    Parameters
    ----------
    flow_dir_mfd : xarray.DataArray or xr.Dataset
        3-D MFD flow direction array of shape ``(8, H, W)`` as returned
        by ``flow_direction_mfd``.  Values are flow fractions in
        ``[0, 1]`` that sum to 1.0 at each cell (0.0 at pits/flats, NaN
        at edges or nodata cells).
        Supported backends: NumPy, CuPy, NumPy-backed Dask,
        CuPy-backed Dask.
        If a Dataset is passed, the operation is applied to each data
        variable independently.
    name : str, default='flow_accumulation_mfd'
        Name of output DataArray.

    Returns
    -------
    xarray.DataArray or xr.Dataset
        2-D float64 array of flow accumulation values.  Each cell holds
        the total upstream contributing area (including itself) that
        drains through it, weighted by MFD fractions.  NaN where the
        input has NaN.  Coordinates that depend on the leading
        (direction) dimension are dropped; attrs are carried over.

    Raises
    ------
    ValueError
        If the input is not shaped ``(8, H, W)``.  The NumPy and Dask
        kernels also raise it when the fraction grid contains a cycle.
    MemoryError
        If the in-memory NumPy / CuPy computation would exceed half of
        the available host / device memory.
    TypeError
        If the underlying array type is not supported.

    References
    ----------
    Qin, C., Zhu, A.X., Pei, T., Li, B., Zhou, C., and Yang, L. (2007).
    An adaptive approach to selecting a flow-partition exponent for a
    multiple-flow-direction algorithm. International Journal of
    Geographical Information Science, 21(4), 443-458.

    Quinn, P., Beven, K., Chevallier, P., and Planchon, O. (1991).
    The prediction of hillslope flow paths for distributed hydrological
    modelling using digital terrain models. Hydrological Processes,
    5(1), 59-79.
    """
    _validate_raster(flow_dir_mfd, func_name='flow_accumulation_mfd',
                     name='flow_dir_mfd', ndim=3)

    data = flow_dir_mfd.data
    if data.ndim != 3 or data.shape[0] != 8:
        raise ValueError(
            "flow_dir_mfd must be a 3-D array of shape (8, H, W), "
            f"got shape {data.shape}"
        )

    _validate_mfd_fractions(data, func_name='flow_accumulation_mfd',
                            name='flow_dir_mfd')

    if isinstance(data, np.ndarray):
        _check_memory(data.shape[1], data.shape[2])
        # The kernel only reads the fractions, so a float64 input is
        # passed through without an extra 64 bytes/pixel copy.
        out = _flow_accum_mfd_cpu(
            data.astype(np.float64, copy=False),
            data.shape[1], data.shape[2])
    elif has_cuda_and_cupy() and is_cupy_array(data):
        _check_gpu_memory(data.shape[1], data.shape[2])
        out = _flow_accum_mfd_cupy(data)
    elif has_cuda_and_cupy() and is_dask_cupy(flow_dir_mfd):
        # Spatial chunk sizes from dims 1 and 2
        chunks_y = data.chunks[1]
        chunks_x = data.chunks[2]
        out = _flow_accum_mfd_dask_cupy(data, chunks_y, chunks_x)
    elif da is not None and isinstance(data, da.Array):
        chunks_y = data.chunks[1]
        chunks_x = data.chunks[2]
        out = _flow_accum_mfd_dask_iterative(data, chunks_y, chunks_x)
    else:
        raise TypeError(f"Unsupported array type: {type(data)}")

    # Build 2-D output coords: keep only coordinates that do not depend on
    # the leading (neighbor/direction) dim.  Filtering on ``v.dims`` also
    # drops auxiliary coords defined along that dim under another name,
    # which the 2-D DataArray constructor would otherwise reject.
    neighbor_dim = flow_dir_mfd.dims[0]
    spatial_dims = flow_dir_mfd.dims[1:]
    coords = {
        k: v for k, v in flow_dir_mfd.coords.items()
        if k != 'neighbor' and neighbor_dim not in v.dims
    }
    # Copy spatial coordinate arrays
    for d in spatial_dims:
        if d in flow_dir_mfd.coords:
            coords[d] = flow_dir_mfd.coords[d]

    return xr.DataArray(out, name=name, coords=coords, dims=spatial_dims,
                        attrs=flow_dir_mfd.attrs)