# Source code for xrspatial.hydro.watershed_mfd

"""MFD watershed delineation.

Labels each cell with the pour point it drains to, using MFD
dominant-neighbor downstream tracing with path compression.

Algorithm
---------
CPU : downstream tracing with path compression, following the neighbor
      with the highest fraction at each step.
GPU : CuPy-via-CPU.
Dask: iterative tile sweep with exit-label propagation.
"""

from __future__ import annotations

import numpy as np
import xarray as xr

try:
    import cupy
except ImportError:
    class cupy:  # type: ignore[no-redef]
        ndarray = False

try:
    import dask.array as da
except ImportError:
    da = None

from xrspatial.hydro._boundary_store import BoundaryStore
from xrspatial.utils import (
    _validate_matching_shape,
    _validate_mfd_fractions,
    _validate_raster,
    has_cuda_and_cupy,
    is_cupy_array,
    is_dask_cupy,
    ngjit,
)
from xrspatial.dataset_support import supports_dataset

# Row (dy) and column (dx) offsets of the eight neighbours, indexed by the
# MFD fraction band k (0..7).  Same values and order as the dy/dx arrays
# built inside the jitted kernels; these plain lists are what the
# pure-Python helper ``_dominant_offset_mfd_py`` looks up.
_DY_LIST = [0, 1, 1, 1, 0, -1, -1, -1]
_DX_LIST = [1, 1, 0, -1, -1, -1, 0, 1]


def _to_numpy_f64(arr):
    """Return ``arr`` as a float64 NumPy array.

    If ``arr`` exposes a ``get`` attribute, ``arr.get()`` is called first
    and its result is what gets converted; otherwise ``arr`` is converted
    directly.
    """
    host = arr.get() if hasattr(arr, 'get') else arr
    return np.asarray(host, dtype=np.float64)


# =====================================================================
# Memory guards
# =====================================================================
#
# CPU peak working set per pixel for the numpy dispatch + the
# ``_watershed_mfd_cpu`` kernel:
#   fr (float64 cast of (8, H, W) fractions) -> 64
#   labels (float64)                         -> 8
#   state  (int8)                            -> 1
#   path_r (int64)                           -> 8
#   path_c (int64)                           -> 8
# Total ~97 bytes/pixel.  The MFD fractions buffer is the new cost
# relative to ``watershed_d8`` (which only needs an 8 B/pixel D8 array).
# Used by ``_check_memory`` to estimate the host working set of a grid.
_BYTES_PER_PIXEL = 97

# GPU peak working set per pixel for ``_watershed_mfd_cupy``.  The
# function copies the device fractions array to the host, runs the CPU
# kernel, and ships the result back.  Device-resident peak is the
# caller's float64 fractions input (8 channels x 8 bytes = 64) plus the
# final ``cp.asarray(out)`` (8) -> 72 bytes/pixel.
# Used by ``_check_gpu_memory`` to estimate the device working set.
_GPU_BYTES_PER_PIXEL = 72


def _available_memory_bytes():
    """Best-effort estimate of available host memory in bytes.

    Tries, in order: the ``MemAvailable`` entry of ``/proc/meminfo``
    (reported in kB), then ``psutil.virtual_memory().available``.  If
    neither source works, a fixed 2 GiB figure is returned.
    """
    fallback = 2 * 1024 ** 3

    try:
        with open('/proc/meminfo', 'r') as meminfo:
            for entry in meminfo:
                if not entry.startswith('MemAvailable:'):
                    continue
                kib = int(entry.split()[1])
                return kib * 1024  # kB -> bytes
    except (OSError, ValueError, IndexError):
        # Missing file or malformed entry: fall through to psutil.
        pass

    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass

    return fallback


def _available_gpu_memory_bytes():
    """Best-effort estimate of free GPU memory in bytes.

    Returns 0 when CuPy cannot be imported or the runtime query fails for
    any reason; callers treat 0 as "no GPU info, skip the guard".
    """
    free_bytes = 0
    try:
        import cupy as _cp
        reported_free, _reported_total = _cp.cuda.runtime.memGetInfo()
        free_bytes = int(reported_free)
    except Exception:
        # Deliberately broad: any failure means "unknown", not an error.
        free_bytes = 0
    return free_bytes


def _check_memory(height, width):
    """Guard the in-memory kernel against oversized grids.

    Estimates the working set as ``height * width * _BYTES_PER_PIXEL`` and
    compares it with the value reported by ``_available_memory_bytes``.

    Raises
    ------
    MemoryError
        If the estimate exceeds 50% of the available host memory.
    """
    n_pixels = int(height) * int(width)
    required = n_pixels * _BYTES_PER_PIXEL
    available = _available_memory_bytes()
    if required <= 0.5 * available:
        return
    raise MemoryError(
        f"watershed_mfd on a {height}x{width} grid requires "
        f"~{required / 1e9:.1f} GB of working memory but only "
        f"~{available / 1e9:.1f} GB is available.  Use a "
        f"dask-backed DataArray for out-of-core processing."
    )


def _check_gpu_memory(height, width):
    """Guard the cupy path against grids too large for the active device.

    The check is skipped (silent return) when
    ``_available_gpu_memory_bytes`` reports a non-positive value, i.e. when
    free device memory could not be determined.

    Raises
    ------
    MemoryError
        If ``height * width * _GPU_BYTES_PER_PIXEL`` exceeds 50% of the
        free GPU memory.
    """
    available = _available_gpu_memory_bytes()
    if not available > 0:
        # No usable figure: nothing to compare against.
        return

    n_pixels = int(height) * int(width)
    required = n_pixels * _GPU_BYTES_PER_PIXEL
    if required <= 0.5 * available:
        return
    raise MemoryError(
        f"watershed_mfd on a {height}x{width} grid requires "
        f"~{required / 1e9:.1f} GB of GPU working memory but only "
        f"~{available / 1e9:.1f} GB is free on the active device.  "
        f"Use a dask+cupy DataArray for out-of-core processing."
    )


def _dominant_offset_mfd_py(fractions_8):
    """Return the (dy, dx) offset of the neighbour with the largest fraction.

    Only strictly positive fractions can win, and on ties the lowest band
    index is kept.  When no band qualifies (all zero, negative or NaN) the
    result is ``(0, 0)``.
    """
    winner = -1
    peak = 0.0
    for k in range(8):
        value = float(fractions_8[k])
        if value > peak:
            winner, peak = k, value
    return (0, 0) if winner < 0 else (_DY_LIST[winner], _DX_LIST[winner])


# =====================================================================
# CPU kernel
# =====================================================================

@ngjit
def _watershed_mfd_cpu(fractions, labels, state, h, w):
    """Downstream tracing with path compression for MFD watershed.

    Parameters
    ----------
    fractions : array indexed as ``fractions[k, r, c]`` for k in 0..7
        MFD flow fractions; NaN in band 0 stops a trace at that cell.
    labels : 2D float array, modified in place
        On entry holds the seed labels; on exit every traced cell holds
        the label it reached, or NaN.
    state : 2D integer array, modified in place
        Per-cell status code: 1 = still to be traced, 2 = on the path
        currently being traced, 3 = labelled, any other value (0 is what
        this kernel writes) = resolved without a label.
    h, w : int
        Grid height and width.

    Returns
    -------
    The ``labels`` array passed in.
    """
    # Neighbour offsets for band k; same order as the fraction bands.
    dy = np.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=np.int64)
    dx = np.array([1, 1, 0, -1, -1, -1, 0, 1], dtype=np.int64)

    # Scratch stack for the current path.  A cell is marked 2 as soon as it
    # is pushed, so it can be pushed at most once per trace and h * w
    # entries are always enough.
    path_r = np.empty(h * w, dtype=np.int64)
    path_c = np.empty(h * w, dtype=np.int64)

    for r in range(h):
        for c in range(w):
            if state[r, c] != 1:
                continue

            path_len = 0
            cr, cc = r, c
            found_label = np.nan
            found = False

            while True:
                s = state[cr, cc]
                if s == 3:
                    # Reached a labelled cell: the whole path inherits it.
                    found_label = labels[cr, cc]
                    found = True
                    break
                if s != 1:
                    # Either an unlabelled resolved cell or a cell already
                    # on this path (state 2, i.e. a loop): no label.
                    break

                path_r[path_len] = cr
                path_c[path_len] = cc
                path_len += 1
                state[cr, cc] = 2

                chk = fractions[0, cr, cc]
                if chk != chk:  # NaN
                    break

                # Pick the neighbour with the largest strictly positive
                # fraction; the first band wins ties.
                best_k = -1
                best_frac = 0.0
                for k in range(8):
                    f = fractions[k, cr, cc]
                    if f > best_frac:
                        best_frac = f
                        best_k = k
                if best_k == -1:
                    # No outflow from this cell.
                    break

                nr, nc = cr + dy[best_k], cc + dx[best_k]
                if nr < 0 or nr >= h or nc < 0 or nc >= w:
                    # Flow leaves the grid without meeting a label.
                    break
                cr, cc = nr, nc

            # Path compression: settle every cell visited on this trace so
            # later traces stop as soon as they touch one of them.
            for i in range(path_len):
                if found:
                    labels[path_r[i], path_c[i]] = found_label
                    state[path_r[i], path_c[i]] = 3
                else:
                    labels[path_r[i], path_c[i]] = np.nan
                    state[path_r[i], path_c[i]] = 0

    return labels


# =====================================================================
# CuPy backend
# =====================================================================

def _watershed_mfd_cupy(fractions_data, pour_points_data):
    """Run the MFD watershed for CuPy inputs via the CPU kernel.

    Both inputs are copied to host float64 arrays, the per-cell seed
    ``labels``/``state`` arrays are built, ``_watershed_mfd_cpu`` is run on
    the host, and the result is moved back to the device.

    Parameters
    ----------
    fractions_data : array of shape (8, H, W)
        MFD flow fractions.
    pour_points_data : array of shape (H, W)
        Pour-point labels; NaN where a cell is not a pour point.

    Returns
    -------
    cupy.ndarray
        (H, W) float64 labels, NaN where no pour point is reached.
    """
    import cupy as cp

    fr_np = _to_numpy_f64(fractions_data)
    pp_np = _to_numpy_f64(pour_points_data)
    _, h, w = fr_np.shape

    # Seed the kernel with vectorized masks instead of a per-pixel Python
    # loop (this function is not jitted, so that loop ran in the
    # interpreter for every cell).
    #   state 0: band 0 is NaN (nodata)      -> label stays NaN
    #   state 3: valid cell with a pour point -> label = pour point value
    #   state 1: valid cell still to be traced
    valid = ~np.isnan(fr_np[0])
    seeded = valid & ~np.isnan(pp_np)
    labels = np.where(seeded, pp_np, np.nan)
    state = np.zeros((h, w), dtype=np.int8)
    state[valid] = 1
    state[seeded] = 3

    out = _watershed_mfd_cpu(fr_np, labels, state, h, w)
    return cp.asarray(out)


# =====================================================================
# Dask tile kernel
# =====================================================================

@ngjit
def _watershed_mfd_tile_kernel(fractions, h, w, pour_points,
                                exit_top, exit_bottom, exit_left, exit_right,
                                exit_tl, exit_tr, exit_bl, exit_br):
    """Seeded downstream tracing for an MFD tile.

    Parameters
    ----------
    fractions : array indexed as ``fractions[k, r, c]`` for k in 0..7
        MFD flow fractions of the tile; NaN in band 0 marks a cell that is
        never traced.
    h, w : int
        Tile height and width.
    pour_points : 2D array
        Pour-point labels for the tile; NaN where there is none.
    exit_top, exit_bottom : 1D arrays indexed by column
        Labels to seed onto the first / last row; NaN entries are ignored.
    exit_left, exit_right : 1D arrays indexed by row
        Labels to seed onto the first / last column; NaN entries are
        ignored.
    exit_tl, exit_tr, exit_bl, exit_br : float
        Labels to seed onto the four corner cells; NaN is ignored.

    Returns
    -------
    2D float64 array of shape (h, w) with the label reached by each cell,
    or NaN.  Cells whose path leaves the tile without meeting a label are
    left NaN.
    """
    # Neighbour offsets for band k; same order as the fraction bands.
    dy = np.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=np.int64)
    dx = np.array([1, 1, 0, -1, -1, -1, 0, 1], dtype=np.int64)

    labels = np.empty((h, w), dtype=np.float64)
    state = np.empty((h, w), dtype=np.int8)

    # Initial classification.  State codes: 0 = resolved without a label,
    # 1 = still to be traced, 2 = on the path being traced, 3 = labelled.
    for r in range(h):
        for c in range(w):
            v = fractions[0, r, c]
            if v != v:
                labels[r, c] = np.nan
                state[r, c] = 0
                continue
            pp = pour_points[r, c]
            if pp == pp:
                labels[r, c] = pp
                state[r, c] = 3
                continue
            labels[r, c] = np.nan
            state[r, c] = 1

    # Apply exit labels
    # Only cells still in state 1 are seeded, so pour points and nodata
    # cells are never overwritten, and the first side to supply a label
    # for a shared corner cell wins.
    for c in range(w):
        if state[0, c] == 1:
            el = exit_top[c]
            if el == el:
                labels[0, c] = el
                state[0, c] = 3
    for c in range(w):
        if state[h - 1, c] == 1:
            el = exit_bottom[c]
            if el == el:
                labels[h - 1, c] = el
                state[h - 1, c] = 3
    for r in range(h):
        if state[r, 0] == 1:
            el = exit_left[r]
            if el == el:
                labels[r, 0] = el
                state[r, 0] = 3
    for r in range(h):
        if state[r, w - 1] == 1:
            el = exit_right[r]
            if el == el:
                labels[r, w - 1] = el
                state[r, w - 1] = 3

    # Corner labels are applied last and only to corners still unlabelled.
    if state[0, 0] == 1 and exit_tl == exit_tl:
        labels[0, 0] = exit_tl
        state[0, 0] = 3
    if state[0, w - 1] == 1 and exit_tr == exit_tr:
        labels[0, w - 1] = exit_tr
        state[0, w - 1] = 3
    if state[h - 1, 0] == 1 and exit_bl == exit_bl:
        labels[h - 1, 0] = exit_bl
        state[h - 1, 0] = 3
    if state[h - 1, w - 1] == 1 and exit_br == exit_br:
        labels[h - 1, w - 1] = exit_br
        state[h - 1, w - 1] = 3

    # Downstream tracing
    # Scratch stack for the current path; a cell is pushed at most once per
    # trace because it is marked 2 immediately, so h * w entries suffice.
    path_r = np.empty(h * w, dtype=np.int64)
    path_c = np.empty(h * w, dtype=np.int64)

    for r in range(h):
        for c in range(w):
            if state[r, c] != 1:
                continue

            path_len = 0
            cr, cc = r, c
            found_label = np.nan
            found = False
            exit_tile = False

            while True:
                s = state[cr, cc]
                if s == 3:
                    # Reached a labelled cell: the whole path inherits it.
                    found_label = labels[cr, cc]
                    found = True
                    break
                if s != 1:
                    # Unlabelled resolved cell, or a cell already on this
                    # path (state 2, i.e. a loop): no label.
                    break

                path_r[path_len] = cr
                path_c[path_len] = cc
                path_len += 1
                state[cr, cc] = 2

                chk = fractions[0, cr, cc]
                if chk != chk:
                    break

                # Largest strictly positive fraction wins; first band on
                # ties.
                best_k = -1
                best_frac = 0.0
                for k in range(8):
                    f = fractions[k, cr, cc]
                    if f > best_frac:
                        best_frac = f
                        best_k = k
                if best_k == -1:
                    # No outflow from this cell.
                    break

                nr, nc = cr + dy[best_k], cc + dx[best_k]
                if nr < 0 or nr >= h or nc < 0 or nc >= w:
                    # Flow crosses the tile edge without meeting a label.
                    exit_tile = True
                    break
                cr, cc = nr, nc

            for i in range(path_len):
                if found:
                    labels[path_r[i], path_c[i]] = found_label
                    state[path_r[i], path_c[i]] = 3
                elif exit_tile:
                    # Undecided inside this tile: keep the cell unlabelled
                    # and put it back to state 1 rather than marking it as
                    # resolved.
                    state[path_r[i], path_c[i]] = 1
                else:
                    labels[path_r[i], path_c[i]] = np.nan
                    state[path_r[i], path_c[i]] = 0

    return labels


# =====================================================================
# Dask iterative tile sweep
# =====================================================================

def _preprocess_mfd_tiles(fractions_da, chunks_y, chunks_x):
    """Collect the boundary fraction strips of every tile.

    Each tile is computed, cast to float64, and its first/last row and
    first/last column (all 8 bands) are copied out.

    Returns
    -------
    dict
        Maps ``(side, iy, ix)`` with ``side`` in ``'top'``, ``'bottom'``,
        ``'left'``, ``'right'`` to an array of shape ``(8, n)`` holding the
        fractions along that edge of tile ``(iy, ix)``.
    """
    def _edges(sizes):
        # Running offsets: edges[i] is where tile i starts.
        edges = [0]
        for size in sizes:
            edges.append(edges[-1] + size)
        return edges

    y_edges = _edges(chunks_y)
    x_edges = _edges(chunks_x)

    strips = {}
    for iy, (y0, y1) in enumerate(zip(y_edges[:-1], y_edges[1:])):
        for ix, (x0, x1) in enumerate(zip(x_edges[:-1], x_edges[1:])):
            tile = np.asarray(
                fractions_da[:, y0:y1, x0:x1].compute(), dtype=np.float64)
            strips[('top', iy, ix)] = tile[:, 0, :].copy()
            strips[('bottom', iy, ix)] = tile[:, -1, :].copy()
            strips[('left', iy, ix)] = tile[:, :, 0].copy()
            strips[('right', iy, ix)] = tile[:, :, -1].copy()
    return strips


def _compute_exit_labels_mfd(iy, ix, boundaries, frac_bdry,
                             chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Compute exit labels for MFD tile using dominant neighbor.

    For every boundary cell of tile ``(iy, ix)`` whose dominant neighbour
    lies in an adjacent tile, look up the label currently stored for that
    neighbour in ``boundaries``.

    Returns
    -------
    tuple
        ``(exit_top, exit_bottom, exit_left, exit_right, exit_tl, exit_tr,
        exit_bl, exit_br)``: four 1D arrays (one entry per boundary cell)
        and four corner scalars.  NaN means "no label available".
    """
    tile_h = chunks_y[iy]
    tile_w = chunks_x[ix]

    has_north = iy > 0
    has_south = iy < n_tile_y - 1
    has_west = ix > 0
    has_east = ix < n_tile_x - 1

    exit_top = np.full(tile_w, np.nan)
    exit_bottom = np.full(tile_w, np.nan)
    exit_left = np.full(tile_h, np.nan)
    exit_right = np.full(tile_h, np.nan)

    def _fill_row(strip, step_y, nb_side, nb_iy, out):
        # Row cells flowing vertically out of the tile.  The target column
        # may fall in the tile directly above/below or, at either end of
        # the row, in the diagonal tile.
        facing = boundaries.get(nb_side, nb_iy, ix)
        n_facing = len(facing)
        for j in range(tile_w):
            off_y, off_x = _dominant_offset_mfd_py(strip[:, j])
            if off_y != step_y:
                continue
            dj = j + off_x
            if 0 <= dj < n_facing:
                out[j] = facing[dj]
            elif dj < 0:
                if has_west:
                    out[j] = boundaries.get(nb_side, nb_iy, ix - 1)[-1]
            elif has_east:
                out[j] = boundaries.get(nb_side, nb_iy, ix + 1)[0]

    def _fill_col(strip, step_x, nb_side, nb_ix, out):
        # Column cells flowing horizontally out of the tile; targets that
        # fall outside the facing column are left to the corner handling.
        facing = boundaries.get(nb_side, iy, nb_ix)
        n_facing = len(facing)
        for r in range(tile_h):
            off_y, off_x = _dominant_offset_mfd_py(strip[:, r])
            if off_x != step_x:
                continue
            dr = r + off_y
            if 0 <= dr < n_facing:
                out[r] = facing[dr]

    def _clear_outward(strip, count, axis, step, out):
        # Cells on the outer grid edge that flow off the grid get NaN.
        for i in range(count):
            if _dominant_offset_mfd_py(strip[:, i])[axis] == step:
                out[i] = np.nan

    def _corner(strip, idx, direction, reachable, nb_side, nb_iy, nb_ix,
                pick):
        # Label of the diagonal tile's facing corner, when the corner cell
        # flows diagonally and that tile exists; NaN otherwise.
        if _dominant_offset_mfd_py(strip[:, idx]) != direction:
            return np.nan
        if not reachable:
            return np.nan
        return boundaries.get(nb_side, nb_iy, nb_ix)[pick]

    # Top row
    fdir_top = frac_bdry.get(('top', iy, ix))
    if fdir_top is not None and has_north:
        _fill_row(fdir_top, -1, 'bottom', iy - 1, exit_top)

    # Bottom row
    fdir_bot = frac_bdry.get(('bottom', iy, ix))
    if fdir_bot is not None and has_south:
        _fill_row(fdir_bot, 1, 'top', iy + 1, exit_bottom)

    # Left column
    fdir_left = frac_bdry.get(('left', iy, ix))
    if fdir_left is not None and has_west:
        _fill_col(fdir_left, -1, 'right', ix - 1, exit_left)

    # Right column
    fdir_right = frac_bdry.get(('right', iy, ix))
    if fdir_right is not None and has_east:
        _fill_col(fdir_right, 1, 'left', ix + 1, exit_right)

    # Edge-of-grid exits
    if not has_north and fdir_top is not None:
        _clear_outward(fdir_top, tile_w, 0, -1, exit_top)
    if not has_south and fdir_bot is not None:
        _clear_outward(fdir_bot, tile_w, 0, 1, exit_bottom)
    if not has_west and fdir_left is not None:
        _clear_outward(fdir_left, tile_h, 1, -1, exit_left)
    if not has_east and fdir_right is not None:
        _clear_outward(fdir_right, tile_h, 1, 1, exit_right)

    # Diagonal corners
    exit_tl = np.nan
    exit_tr = np.nan
    exit_bl = np.nan
    exit_br = np.nan
    if fdir_top is not None:
        exit_tl = _corner(fdir_top, 0, (-1, -1), has_north and has_west,
                          'bottom', iy - 1, ix - 1, -1)
        exit_tr = _corner(fdir_top, -1, (-1, 1), has_north and has_east,
                          'bottom', iy - 1, ix + 1, 0)
    if fdir_bot is not None:
        exit_bl = _corner(fdir_bot, 0, (1, -1), has_south and has_west,
                          'top', iy + 1, ix - 1, -1)
        exit_br = _corner(fdir_bot, -1, (1, 1), has_south and has_east,
                          'top', iy + 1, ix + 1, 0)

    return (exit_top, exit_bottom, exit_left, exit_right,
            exit_tl, exit_tr, exit_bl, exit_br)


def _process_tile_mfd(iy, ix, fractions_da, pour_points_da,
                      boundaries, frac_bdry,
                      chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Re-run one tile and publish its boundary labels.

    Computes tile ``(iy, ix)``, runs the tile kernel seeded with the
    current exit labels, and writes the tile's four edge label strips into
    ``boundaries``.

    Returns
    -------
    bool
        True if any edge strip differs from what ``boundaries`` held before
        this call (positions that are NaN both before and after count as
        unchanged).
    """
    y0 = sum(chunks_y[:iy])
    y1 = y0 + chunks_y[iy]
    x0 = sum(chunks_x[:ix])
    x1 = x0 + chunks_x[ix]

    frac_tile = np.asarray(
        fractions_da[:, y0:y1, x0:x1].compute(), dtype=np.float64)
    pp_tile = np.asarray(
        pour_points_da.blocks[iy, ix].compute(), dtype=np.float64)
    _, h, w = frac_tile.shape

    exits = _compute_exit_labels_mfd(
        iy, ix, boundaries, frac_bdry,
        chunks_y, chunks_x, n_tile_y, n_tile_x)
    tile_labels = _watershed_mfd_tile_kernel(frac_tile, h, w, pp_tile, *exits)

    fresh = {
        'top': tile_labels[0, :].copy(),
        'bottom': tile_labels[-1, :].copy(),
        'left': tile_labels[:, 0].copy(),
        'right': tile_labels[:, -1].copy(),
    }

    # Compare against the stored strips before overwriting them; stop at
    # the first side that shows a difference.
    changed = False
    for side, new in fresh.items():
        old = boundaries.get(side, iy, ix).copy()
        with np.errstate(invalid='ignore'):
            comparable = ~(np.isnan(old) & np.isnan(new))
            if comparable.any() and np.any(old[comparable] != new[comparable]):
                changed = True
                break

    for side, new in fresh.items():
        boundaries.set(side, iy, ix, new)

    return changed


def _watershed_mfd_dask(fractions_da, pour_points_da, chunks_y, chunks_x):
    """Dask backend: iterate tile sweeps, then assemble the result lazily.

    Boundary labels are propagated between tiles by repeated forward and
    reverse sweeps until a full pass changes nothing or the iteration cap
    (``max(n_tile_y, n_tile_x) * 2 + 10``) is reached.  The returned dask
    array re-runs the tile kernel per block at compute time using the
    final boundary snapshot.
    """
    n_tile_y, n_tile_x = len(chunks_y), len(chunks_x)

    # The 8 direction bands must stay in a single chunk: every tile kernel
    # needs all 8 fractions, and the lazy assembly drops axis 0 per block.
    n_bands = fractions_da.shape[0]
    if fractions_da.chunks[0] != (n_bands,):
        fractions_da = fractions_da.rechunk({0: n_bands})
    # Align pour points to the fractions' spatial tile grid so the lazy
    # assembly can map both arrays block-for-block.
    pour_points_da = pour_points_da.rechunk((chunks_y, chunks_x))

    frac_bdry = _preprocess_mfd_tiles(fractions_da, chunks_y, chunks_x)
    boundaries = BoundaryStore(chunks_y, chunks_x, fill_value=np.nan)

    # One iteration = a row-major pass followed by the exact reverse pass.
    forward = [(iy, ix) for iy in range(n_tile_y) for ix in range(n_tile_x)]
    sweep = forward + forward[::-1]

    remaining = max(n_tile_y, n_tile_x) * 2 + 10
    while remaining > 0:
        remaining -= 1
        any_changed = False
        for iy, ix in sweep:
            # Every tile is processed even after a change was seen.
            if _process_tile_mfd(
                    iy, ix, fractions_da, pour_points_da,
                    boundaries, frac_bdry,
                    chunks_y, chunks_x, n_tile_y, n_tile_x):
                any_changed = True
        if not any_changed:
            break

    frozen = boundaries.snapshot()

    # Assemble the final result lazily.  The boundary snapshot and fraction
    # strips are captured in the closure below and map_blocks runs the
    # per-tile kernel at compute time, so the full output raster is not
    # materialized during the API call.
    y_starts = np.cumsum((0,) + tuple(chunks_y[:-1]))
    x_starts = np.cumsum((0,) + tuple(chunks_x[:-1]))

    def _tile(chunk, pp_chunk, block_info=None):
        # Recover the tile indices from the block's start offsets.
        location = block_info[0]['array-location']
        row0 = location[1][0]
        col0 = location[2][0]
        tile_iy = int(np.searchsorted(y_starts, row0, side='right')) - 1
        tile_ix = int(np.searchsorted(x_starts, col0, side='right')) - 1

        frac_tile = np.asarray(chunk, dtype=np.float64)
        pp_tile = np.asarray(pp_chunk, dtype=np.float64)
        _, h, w = frac_tile.shape
        exits = _compute_exit_labels_mfd(
            tile_iy, tile_ix, frozen, frac_bdry,
            chunks_y, chunks_x, n_tile_y, n_tile_x)
        return _watershed_mfd_tile_kernel(frac_tile, h, w, pp_tile, *exits)

    return da.map_blocks(
        _tile, fractions_da, pour_points_da, drop_axis=0,
        dtype=np.float64, meta=np.array((), dtype=np.float64),
    )


# =====================================================================
# Dask+CuPy backend
# =====================================================================

def _watershed_mfd_dask_cupy(fractions_da, pour_points_da, chunks_y, chunks_x):
    """Dask+CuPy backend: run the dask backend on host copies of the blocks.

    Each block of both inputs is lazily converted with ``.get()``, the
    regular dask backend is run on those arrays, and every result block is
    converted back with ``cupy.asarray``.
    """
    import cupy as cp

    def _host_view(lazy):
        # Lazily pull each block off the device; dtype is preserved.
        return lazy.map_blocks(
            lambda b: b.get(), dtype=lazy.dtype,
            meta=np.array((), dtype=lazy.dtype),
        )

    host_labels = _watershed_mfd_dask(
        _host_view(fractions_da), _host_view(pour_points_da),
        chunks_y, chunks_x)
    return host_labels.map_blocks(
        cp.asarray, dtype=host_labels.dtype,
        meta=cp.array((), dtype=host_labels.dtype),
    )


# =====================================================================
# Public API
# =====================================================================

@supports_dataset
def watershed_mfd(flow_dir_mfd: xr.DataArray,
                  pour_points: xr.DataArray,
                  name: str = 'watershed_mfd') -> xr.DataArray:
    """Label each cell with the pour point it drains to (MFD).

    Parameters
    ----------
    flow_dir_mfd : xarray.DataArray or xr.Dataset
        3D MFD flow direction array of shape (8, H, W).
    pour_points : xarray.DataArray
        2D raster where non-NaN cells are pour points.
    name : str, default='watershed_mfd'
        Name of output DataArray.

    Returns
    -------
    xarray.DataArray or xr.Dataset
        2D float64 array where each cell = label of its pour point.
        NaN for nodata or unreachable cells.

    Raises
    ------
    ValueError
        If ``flow_dir_mfd`` does not have shape (8, H, W).
    MemoryError
        If the NumPy or CuPy path would exceed the memory guard.
    TypeError
        If the underlying array type is not supported.
    """
    _validate_raster(flow_dir_mfd, func_name='watershed_mfd',
                     name='flow_dir_mfd', ndim=3)
    _validate_raster(pour_points, func_name='watershed_mfd',
                     name='pour_points')

    data = flow_dir_mfd.data
    pp_data = pour_points.data

    if data.ndim != 3 or data.shape[0] != 8:
        raise ValueError(
            f"flow_dir_mfd must have shape (8, H, W), got {data.shape}")
    _validate_mfd_fractions(data, func_name='watershed_mfd',
                            name='flow_dir_mfd')
    _, H, W = data.shape
    _validate_matching_shape(
        pour_points, (H, W), func_name='watershed_mfd',
        name='pour_points', expected_name='flow_dir_mfd')

    if isinstance(data, np.ndarray):
        _check_memory(H, W)
        fr = data.astype(np.float64)
        pp = np.asarray(pp_data, dtype=np.float64)

        # Seed the kernel with vectorized masks instead of a per-pixel
        # Python loop (this function is not jitted, so that loop ran in
        # the interpreter for every cell).
        #   state 0: band 0 is NaN (nodata)      -> label stays NaN
        #   state 3: valid cell with a pour point -> label = pour point
        #   state 1: valid cell still to be traced
        valid = ~np.isnan(fr[0])
        seeded = valid & ~np.isnan(pp)
        labels = np.where(seeded, pp, np.nan)
        state = np.zeros((H, W), dtype=np.int8)
        state[valid] = 1
        state[seeded] = 3

        out = _watershed_mfd_cpu(fr, labels, state, H, W)

    elif has_cuda_and_cupy() and is_cupy_array(data):
        _check_gpu_memory(H, W)
        out = _watershed_mfd_cupy(data, pp_data)

    elif has_cuda_and_cupy() and is_dask_cupy(flow_dir_mfd):
        chunks_y = data.chunks[1]
        chunks_x = data.chunks[2]
        out = _watershed_mfd_dask_cupy(data, pp_data, chunks_y, chunks_x)

    elif da is not None and isinstance(data, da.Array):
        chunks_y = data.chunks[1]
        chunks_x = data.chunks[2]
        out = _watershed_mfd_dask(data, pp_data, chunks_y, chunks_x)

    else:
        raise TypeError(f"Unsupported array type: {type(data)}")

    # The output is 2D: keep only the spatial dims and their coordinates.
    spatial_dims = flow_dir_mfd.dims[1:]
    coords = {}
    for d in spatial_dims:
        if d in flow_dir_mfd.coords:
            coords[d] = flow_dir_mfd.coords[d]
    return xr.DataArray(out,
                        name=name,
                        coords=coords,
                        dims=spatial_dims,
                        attrs=flow_dir_mfd.attrs)