"""Watershed delineation and drainage basin labeling.
Two complementary functions:
- ``watershed_d8(flow_dir, pour_points)`` — labels each cell with the
  pour point it drains to; cells not reaching any pour point → NaN.
- ``basins_d8(flow_dir)`` — deprecated wrapper that forwards to
  ``basin_d8``; automatically identifies all outlets (pits +
  edge-exit cells) and labels every valid cell; no pour points needed.
Both use **downstream tracing with path compression** on CPU — follow
each cell's flow_dir downstream until hitting a labeled cell, then
label the entire traced path. O(N) amortized.
GPU uses iterative label propagation (one hop per iteration).
Dask uses iterative tile sweep with exit-label propagation.
"""
from __future__ import annotations
import numpy as np
import xarray as xr
from numba import cuda
try:
import cupy
except ImportError:
class cupy: # type: ignore[no-redef]
ndarray = False
try:
import dask.array as da
except ImportError:
da = None
from xrspatial.dataset_support import supports_dataset
from xrspatial.hydro._boundary_store import BoundaryStore
from xrspatial.utils import (_validate_raster, cuda_args, has_cuda_and_cupy, is_cupy_array,
is_dask_cupy, ngjit)
def _to_numpy_f64(arr):
"""Convert *arr* to a contiguous numpy float64 array (handles CuPy)."""
if hasattr(arr, 'get'):
arr = arr.get()
return np.asarray(arr, dtype=np.float64)
# =====================================================================
# Memory guards
# =====================================================================
#
# CPU peak working set per pixel for the numpy dispatch + ``_watershed_cpu``:
#   fd      (float64 cast) -> 8
#   labels  (float64)      -> 8
#   state   (int8)         -> 1
#   path_r  (int64)        -> 8
#   path_c  (int64)        -> 8
# Total ~33 bytes/pixel.  The caller's ``flow_dir`` and ``pour_points``
# arrays already live in RAM before dispatch and are not double-counted.
# NOTE(review): the dispatch converts ``pour_points`` with
# ``np.asarray(..., dtype=float64)``; if the input is not already float64
# that conversion allocates another 8 bytes/pixel not tallied here.
_BYTES_PER_PIXEL = 33
# GPU peak working set per pixel for ``_watershed_cupy``:
#   flow_dir_f64 (float64) -> 8
#   pp_f64       (float64) -> 8
#   labels       (float64) -> 8
#   state        (int32)   -> 4
# Total 28 bytes/pixel on the device.
# NOTE(review): the final ``cp.where`` in ``_watershed_cupy`` produces one
# more float64 array that is not included in this tally.
_GPU_BYTES_PER_PIXEL = 28
def _available_memory_bytes():
"""Best-effort estimate of available host memory in bytes."""
try:
with open('/proc/meminfo', 'r') as f:
for line in f:
if line.startswith('MemAvailable:'):
return int(line.split()[1]) * 1024 # kB -> bytes
except (OSError, ValueError, IndexError):
pass
try:
import psutil
return psutil.virtual_memory().available
except (ImportError, AttributeError):
pass
return 2 * 1024 ** 3
def _available_gpu_memory_bytes():
"""Best-effort estimate of free GPU memory in bytes.
Returns 0 if CuPy / CUDA is unavailable or the query fails -- callers
use that as a sentinel meaning "no GPU info, skip the guard".
"""
try:
import cupy as _cp
free, _total = _cp.cuda.runtime.memGetInfo()
return int(free)
except Exception:
return 0
def _check_memory(height, width):
    """Guard the CPU path: refuse grids needing more than half of free RAM.

    Raises
    ------
    MemoryError
        When ``height * width * _BYTES_PER_PIXEL`` exceeds 50% of the
        value reported by ``_available_memory_bytes``.
    """
    required = _BYTES_PER_PIXEL * (int(height) * int(width))
    available = _available_memory_bytes()
    if required <= 0.5 * available:
        return
    needed_gb = required / 1e9
    free_gb = available / 1e9
    raise MemoryError(
        f"watershed_d8 on a {height}x{width} grid requires "
        f"~{needed_gb:.1f} GB of working memory but only "
        f"~{free_gb:.1f} GB is available. Use a "
        f"dask-backed DataArray for out-of-core processing."
    )
def _check_gpu_memory(height, width):
    """Guard the CuPy path: refuse grids needing more than half of free GPU RAM.

    When ``_available_gpu_memory_bytes`` reports a non-positive value the
    amount of free memory is unknown and the check is skipped silently.

    Raises
    ------
    MemoryError
        When ``height * width * _GPU_BYTES_PER_PIXEL`` exceeds 50% of the
        reported free device memory.
    """
    available = _available_gpu_memory_bytes()
    if available > 0:
        required = _GPU_BYTES_PER_PIXEL * (int(height) * int(width))
        if required > 0.5 * available:
            needed_gb = required / 1e9
            free_gb = available / 1e9
            raise MemoryError(
                f"watershed_d8 on a {height}x{width} grid requires "
                f"~{needed_gb:.1f} GB of GPU working memory but only "
                f"~{free_gb:.1f} GB is free on the active device. "
                f"Use a dask+cupy DataArray for out-of-core processing."
            )
# =====================================================================
# Direction helpers
# =====================================================================
@ngjit
def _code_to_offset(code):
    """Map a D8 direction code to its ``(dy, dx)`` row/column step.

    Codes 1, 2, 4, 8, 16, 32, 64, 128 map to E, SE, S, SW, W, NW, N, NE
    respectively (rows grow downward).  Any other code yields ``(0, 0)``.
    """
    direction = int(code)
    dy = 0
    dx = 0
    if direction == 1:        # E
        dx = 1
    elif direction == 2:      # SE
        dy = 1
        dx = 1
    elif direction == 4:      # S
        dy = 1
    elif direction == 8:      # SW
        dy = 1
        dx = -1
    elif direction == 16:     # W
        dx = -1
    elif direction == 32:     # NW
        dy = -1
        dx = -1
    elif direction == 64:     # N
        dy = -1
    elif direction == 128:    # NE
        dy = -1
        dx = 1
    return dy, dx
def _code_to_offset_py(code):
"""Pure-Python version for non-numba contexts."""
import math
if isinstance(code, float) and math.isnan(code):
return (0, 0)
c = int(code)
_map = {1: (0, 1), 2: (1, 1), 4: (1, 0), 8: (1, -1),
16: (0, -1), 32: (-1, -1), 64: (-1, 0), 128: (-1, 1)}
return _map.get(c, (0, 0))
# =====================================================================
# CPU kernels
# =====================================================================
@ngjit
def _watershed_cpu(flow_dir, labels, state, h, w):
    """Label cells by tracing downstream, compressing each traced path.

    ``state`` is kept separate from ``labels`` so pour-point labels may be
    any float (including negative values).

    State codes: 0 = nodata, 1 = unresolved, 2 = on the current trace,
    3 = resolved.  ``labels`` and ``state`` are updated in place; on
    return every cell that reaches a pour point is in state 3 carrying
    that pour point's label, and every other traced cell is in state 0
    with a NaN label.  Returns ``labels``.
    """
    trail_r = np.empty(h * w, dtype=np.int64)
    trail_c = np.empty(h * w, dtype=np.int64)

    for r0 in range(h):
        for c0 in range(w):
            # Only unresolved cells start a trace.
            if state[r0, c0] != 1:
                continue

            n_trail = 0
            cr, cc = r0, c0
            found_label = np.nan
            found = False

            while True:
                s = state[cr, cc]
                if s == 3:
                    # Reached a pour point or an already-labelled cell.
                    found_label = labels[cr, cc]
                    found = True
                    break
                if s != 1:
                    # nodata (0) or a cell already on this trace (2): give up.
                    break

                trail_r[n_trail] = cr
                trail_c[n_trail] = cc
                n_trail += 1
                state[cr, cc] = 2

                v = flow_dir[cr, cc]
                if v != v:  # NaN direction
                    break
                dy, dx = _code_to_offset(v)
                if dy == 0 and dx == 0:
                    break  # pit that is not a pour point
                nr = cr + dy
                nc = cc + dx
                if nr < 0 or nr >= h or nc < 0 or nc >= w:
                    break  # flows off the grid
                cr, cc = nr, nc

            # Stamp the outcome onto every cell visited by this trace.
            if found:
                for k in range(n_trail):
                    labels[trail_r[k], trail_c[k]] = found_label
                    state[trail_r[k], trail_c[k]] = 3
            else:
                for k in range(n_trail):
                    labels[trail_r[k], trail_c[k]] = np.nan
                    state[trail_r[k], trail_c[k]] = 0

    return labels
# =====================================================================
# GPU kernels
# =====================================================================
@cuda.jit
def _init_watershed_gpu(flow_dir, pour_points, labels, state, H, W):
    """Initialise ``labels`` and ``state`` for GPU label propagation.

    One thread handles one cell ``(i, j)`` of the 2-D launch grid; threads
    outside the ``H`` x ``W`` extent return immediately.

    State written per cell:
    - 0: ``flow_dir`` is NaN (nodata); label set to 0.0.
    - 2: ``pour_points`` is non-NaN; label set to that value and the cell
      starts out as frontier.
    - 1: any other cell; label set to 0.0 and the cell is active
      (still waiting for a label).
    """
    i, j = cuda.grid(2)
    if i >= H or j >= W:
        return

    v = flow_dir[i, j]
    if v != v:  # NaN
        state[i, j] = 0
        labels[i, j] = 0.0
        return

    # A NaN flow direction takes precedence over a pour point (checked above).
    pp = pour_points[i, j]
    if pp == pp:  # not NaN → pour point
        labels[i, j] = pp
        state[i, j] = 2  # frontier
    else:
        labels[i, j] = 0.0
        state[i, j] = 1  # active
@cuda.jit
def _propagate_labels_gpu(flow_dir, labels, state, changed, H, W):
    """Advance labels by one hop: active cells adopt a frontier neighbour's label.

    One thread per cell.  Only cells in state 1 (active) do any work: the
    cell's D8 code is decoded to a ``(dy, dx)`` step and, if the
    downstream cell lies inside the grid and is in state 2 (frontier),
    this cell copies its label, moves to state 3 (newly labelled) and
    atomically increments ``changed[0]``.
    """
    i, j = cuda.grid(2)
    if i >= H or j >= W:
        return

    if state[i, j] != 1:
        return

    v = flow_dir[i, j]
    code = int(v)

    # Inline D8 decode (same table as ``_code_to_offset``); unknown codes
    # leave (0, 0).
    dy = 0
    dx = 0
    if code == 1:
        dy, dx = 0, 1
    elif code == 2:
        dy, dx = 1, 1
    elif code == 4:
        dy, dx = 1, 0
    elif code == 8:
        dy, dx = 1, -1
    elif code == 16:
        dy, dx = 0, -1
    elif code == 32:
        dy, dx = -1, -1
    elif code == 64:
        dy, dx = -1, 0
    elif code == 128:
        dy, dx = -1, 1

    # No movement: pit / unknown code, the cell can never pick up a label.
    if dy == 0 and dx == 0:
        return

    ni = i + dy
    nj = j + dx
    if ni < 0 or ni >= H or nj < 0 or nj >= W:
        return

    if state[ni, nj] == 2:  # downstream is frontier
        labels[i, j] = labels[ni, nj]
        # State 3 (not 2) so neighbours do not see this cell as frontier
        # until ``_advance_frontier_gpu`` has run.
        state[i, j] = 3  # newly labeled
        cuda.atomic.add(changed, 0, 1)
@cuda.jit
def _advance_frontier_gpu(state, H, W):
    """Rotate frontier states after a propagation pass.

    One thread per cell: state 2 → 4 (done), state 3 → 2 (new frontier).
    All other states are left untouched.
    """
    i, j = cuda.grid(2)
    if i >= H or j >= W:
        return
    s = state[i, j]
    if s == 2:
        state[i, j] = 4
    elif s == 3:
        state[i, j] = 2
def _watershed_cupy(flow_dir_data, pour_points_data):
    """Run the whole-array watershed on the GPU.

    Initialises labels/state with ``_init_watershed_gpu`` and then
    alternates ``_propagate_labels_gpu`` / ``_advance_frontier_gpu`` until
    a pass labels no new cell (at most ``H * W`` passes).  Cells left in
    state 0 (nodata) or 1 (never reached) are returned as NaN.
    """
    import cupy as cp

    n_rows, n_cols = flow_dir_data.shape
    shape = (n_rows, n_cols)

    fdir = flow_dir_data.astype(cp.float64)
    seeds = pour_points_data.astype(cp.float64)
    labels = cp.zeros(shape, dtype=cp.float64)
    state = cp.zeros(shape, dtype=cp.int32)
    changed = cp.zeros(1, dtype=cp.int32)

    griddim, blockdim = cuda_args(shape)
    _init_watershed_gpu[griddim, blockdim](
        fdir, seeds, labels, state, n_rows, n_cols)

    passes_left = n_rows * n_cols
    while passes_left > 0:
        passes_left -= 1
        changed[0] = 0
        _propagate_labels_gpu[griddim, blockdim](
            fdir, labels, state, changed, n_rows, n_cols)
        if int(changed[0]) == 0:
            break  # nothing was labelled in this pass
        _advance_frontier_gpu[griddim, blockdim](state, n_rows, n_cols)

    # Unresolved (state 1) and invalid (state 0) cells become NaN.
    unlabeled = (state == 1) | (state == 0)
    return cp.where(unlabeled, cp.nan, labels)
# =====================================================================
# Tile kernel for dask iterative path
# =====================================================================
@ngjit
def _watershed_tile_kernel(flow_dir, h, w, pour_points,
                           exit_top, exit_bottom, exit_left, exit_right,
                           exit_tl, exit_tr, exit_bl, exit_br):
    """Seeded downstream tracing for a single tile.

    Parameters
    ----------
    flow_dir, pour_points : 2-D float64 arrays of shape ``(h, w)``
        D8 codes (NaN = nodata) and pour-point labels (NaN = none).
    exit_top, exit_bottom : 1-D arrays of length ``w``
    exit_left, exit_right : 1-D arrays of length ``h``
    exit_tl, exit_tr, exit_bl, exit_br : float
        Labels already known for the cell each boundary cell drains to
        outside the tile (NaN = not known).  A non-NaN entry seeds the
        corresponding boundary cell if it is still unresolved.

    Returns
    -------
    2-D float64 array
        Per-cell label; NaN for nodata, dead ends, cycles and cells whose
        path leaves the tile without meeting a resolved cell.

    Notes
    -----
    A separate state array is used so pour-point labels can be any float
    value (including negative).  State: 0=nodata/dead end, 1=unresolved,
    2=in-trace, 3=resolved, 4=path leaves the tile unresolved.  State 4
    lets later traces stop as soon as they join a path already known to
    exit the tile, instead of re-walking it to the edge every time.
    """
    labels = np.empty((h, w), dtype=np.float64)
    state = np.empty((h, w), dtype=np.int8)

    # Initialise labels and state
    for r in range(h):
        for c in range(w):
            v = flow_dir[r, c]
            if v != v:  # NaN
                labels[r, c] = np.nan
                state[r, c] = 0
                continue
            pp = pour_points[r, c]
            if pp == pp:  # not NaN → pour point
                labels[r, c] = pp
                state[r, c] = 3
                continue
            labels[r, c] = np.nan
            state[r, c] = 1  # unresolved

    # Seed boundary cells whose out-of-tile destination is already labelled.
    # Top row
    for c in range(w):
        if state[0, c] == 1:
            el = exit_top[c]
            if el == el:  # not NaN → resolved
                labels[0, c] = el
                state[0, c] = 3

    # Bottom row
    for c in range(w):
        if state[h - 1, c] == 1:
            el = exit_bottom[c]
            if el == el:
                labels[h - 1, c] = el
                state[h - 1, c] = 3

    # Left column
    for r in range(h):
        if state[r, 0] == 1:
            el = exit_left[r]
            if el == el:
                labels[r, 0] = el
                state[r, 0] = 3

    # Right column
    for r in range(h):
        if state[r, w - 1] == 1:
            el = exit_right[r]
            if el == el:
                labels[r, w - 1] = el
                state[r, w - 1] = 3

    # Corners (diagonal exits)
    if state[0, 0] == 1 and exit_tl == exit_tl:
        labels[0, 0] = exit_tl
        state[0, 0] = 3
    if state[0, w - 1] == 1 and exit_tr == exit_tr:
        labels[0, w - 1] = exit_tr
        state[0, w - 1] = 3
    if state[h - 1, 0] == 1 and exit_bl == exit_bl:
        labels[h - 1, 0] = exit_bl
        state[h - 1, 0] = 3
    if state[h - 1, w - 1] == 1 and exit_br == exit_br:
        labels[h - 1, w - 1] = exit_br
        state[h - 1, w - 1] = 3

    # Downstream tracing with path compression
    path_r = np.empty(h * w, dtype=np.int64)
    path_c = np.empty(h * w, dtype=np.int64)

    for r in range(h):
        for c in range(w):
            if state[r, c] != 1:
                continue

            path_len = 0
            cr, cc = r, c
            found_label = np.nan
            found = False
            exit_tile = False

            while True:
                s = state[cr, cc]
                if s == 3:
                    found_label = labels[cr, cc]
                    found = True
                    break
                if s == 4:
                    # Joined a path already known to leave the tile
                    # unresolved; this trace shares that fate.
                    exit_tile = True
                    break
                if s != 1:
                    # nodata / dead end (0) or cycle within this trace (2)
                    break

                path_r[path_len] = cr
                path_c[path_len] = cc
                path_len += 1
                state[cr, cc] = 2

                v = flow_dir[cr, cc]
                if v != v:
                    break

                dy, dx = _code_to_offset(v)
                if dy == 0 and dx == 0:
                    break

                nr, nc = cr + dy, cc + dx
                if nr < 0 or nr >= h or nc < 0 or nc >= w:
                    # Exits tile — label stays NaN until a later sweep
                    # supplies an exit label for the boundary cell.
                    exit_tile = True
                    break
                cr, cc = nr, nc

            if found:
                for i in range(path_len):
                    labels[path_r[i], path_c[i]] = found_label
                    state[path_r[i], path_c[i]] = 3
            elif exit_tile:
                for i in range(path_len):
                    state[path_r[i], path_c[i]] = 4  # leaves tile, label NaN
            else:
                for i in range(path_len):
                    labels[path_r[i], path_c[i]] = np.nan
                    state[path_r[i], path_c[i]] = 0  # dead end

    return labels
# =====================================================================
# Dask iterative tile sweep
# =====================================================================
def _preprocess_tiles(flow_dir_da, chunks_y, chunks_x):
    """Collect every tile's four edge strips of flow direction.

    Each block of ``flow_dir_da`` is computed once and its first/last row
    and first/last column are stored (as numpy float64) in a new
    ``BoundaryStore`` under 'top', 'bottom', 'left' and 'right'.
    """
    store = BoundaryStore(chunks_y, chunks_x, fill_value=np.nan)

    for iy in range(len(chunks_y)):
        for ix in range(len(chunks_x)):
            tile = flow_dir_da.blocks[iy, ix].compute()
            strips = (
                ('top', tile[0, :]),
                ('bottom', tile[-1, :]),
                ('left', tile[:, 0]),
                ('right', tile[:, -1]),
            )
            for side, strip in strips:
                store.set(side, iy, ix, _to_numpy_f64(strip))

    return store
def _compute_exit_labels(iy, ix, boundaries, flow_bdry,
                         chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Compute exit labels for tile (iy, ix).

    For each boundary cell of the current tile, check if its flow_dir
    points OUTSIDE the tile.  If so, look up the destination cell's
    current label in the adjacent tile's boundary data.

    This is the reverse of flow_accumulation's seed computation:
    - flow_accum: "who flows INTO my boundary?"  (entry seeds)
    - watershed:  "where does my boundary cell flow TO?"  (exit labels)

    Parameters
    ----------
    iy, ix : int
        Tile coordinates.
    boundaries : store exposing ``get(side, iy, ix)``
        Current label strips of every tile.
    flow_bdry : store exposing ``get(side, iy, ix)``
        Flow-direction strips of every tile.
    chunks_y, chunks_x : sequence of int
        Tile heights / widths.
    n_tile_y, n_tile_x : int
        Number of tile rows / columns.

    Returns
    -------
    tuple
        ``(exit_top, exit_bottom, exit_left, exit_right,
        exit_tl, exit_tr, exit_bl, exit_br)``.  Entries stay NaN when the
        cell does not leave the tile through that side, when it flows off
        the grid, or when the destination has no label yet.
    """
    tile_h = chunks_y[iy]
    tile_w = chunks_x[ix]

    # Everything defaults to NaN; cells that flow off the grid therefore
    # need no special handling below.
    exit_top = np.full(tile_w, np.nan)
    exit_bottom = np.full(tile_w, np.nan)
    exit_left = np.full(tile_h, np.nan)
    exit_right = np.full(tile_h, np.nan)
    exit_tl = np.nan
    exit_tr = np.nan
    exit_bl = np.nan
    exit_br = np.nan

    fdir_top = flow_bdry.get('top', iy, ix)
    fdir_bottom = flow_bdry.get('bottom', iy, ix)
    fdir_left = flow_bdry.get('left', iy, ix)
    fdir_right = flow_bdry.get('right', iy, ix)

    has_up = iy > 0
    has_down = iy < n_tile_y - 1
    has_left = ix > 0
    has_right = ix < n_tile_x - 1

    # --- Top / bottom rows: cells leaving through a horizontal edge ---
    # The destination lives in the facing strip of the tile above/below
    # ('bottom' of the upper tile, 'top' of the lower tile), or in the
    # end cell of the diagonal neighbour's strip when the step also
    # leaves the column range.
    for exits, fdir, dy, nb_side in ((exit_top, fdir_top, -1, 'bottom'),
                                     (exit_bottom, fdir_bottom, 1, 'top')):
        nb_iy = iy + dy
        if not 0 <= nb_iy < n_tile_y:
            continue  # no tile on that side: flows off the grid
        nb_labels = boundaries.get(nb_side, nb_iy, ix)
        n_nb = len(nb_labels)
        for j in range(tile_w):
            d = _code_to_offset_py(fdir[j])
            if d[0] != dy:
                continue
            dj = j + d[1]
            if 0 <= dj < n_nb:
                exits[j] = nb_labels[dj]
            elif dj < 0:
                # Diagonal step past the left end of the strip.
                if has_left:
                    exits[j] = boundaries.get(nb_side, nb_iy, ix - 1)[-1]
            elif d[1] == 1 and has_right:
                # Diagonal step past the right end of the strip.
                exits[j] = boundaries.get(nb_side, nb_iy, ix + 1)[0]

    # --- Left / right columns: cells leaving through a vertical edge ---
    # Diagonal steps from the first/last row land in a diagonal tile and
    # are covered by the corner labels below, so only destinations inside
    # the row range of the side neighbour are taken here.
    for exits, fdir, dx, nb_side in ((exit_left, fdir_left, -1, 'right'),
                                     (exit_right, fdir_right, 1, 'left')):
        nb_ix = ix + dx
        if not 0 <= nb_ix < n_tile_x:
            continue  # no tile on that side: flows off the grid
        nb_labels = boundaries.get(nb_side, iy, nb_ix)
        limit = min(tile_h, len(nb_labels))
        for r in range(tile_h):
            d = _code_to_offset_py(fdir[r])
            if d[1] != dx:
                continue
            dr = r + d[0]
            if 0 <= dr < limit:
                exits[r] = nb_labels[dr]

    # --- Diagonal corners ---
    # TL corner (0, 0) flowing NW into tile (iy-1, ix-1)
    if has_up and has_left and _code_to_offset_py(fdir_top[0]) == (-1, -1):
        exit_tl = boundaries.get('bottom', iy - 1, ix - 1)[-1]
    # TR corner (0, w-1) flowing NE into tile (iy-1, ix+1)
    if has_up and has_right and _code_to_offset_py(fdir_top[-1]) == (-1, 1):
        exit_tr = boundaries.get('bottom', iy - 1, ix + 1)[0]
    # BL corner (h-1, 0) flowing SW into tile (iy+1, ix-1)
    if has_down and has_left and _code_to_offset_py(fdir_bottom[0]) == (1, -1):
        exit_bl = boundaries.get('top', iy + 1, ix - 1)[-1]
    # BR corner (h-1, w-1) flowing SE into tile (iy+1, ix+1)
    if (has_down and has_right
            and _code_to_offset_py(fdir_bottom[-1]) == (1, 1)):
        exit_br = boundaries.get('top', iy + 1, ix + 1)[0]

    return (exit_top, exit_bottom, exit_left, exit_right,
            exit_tl, exit_tr, exit_bl, exit_br)
def _process_tile_watershed(iy, ix, flow_dir_da, pour_points_da,
                            boundaries, flow_bdry,
                            chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Re-label one tile on the CPU and refresh its boundary strips.

    The tile's flow-direction and pour-point blocks are computed, traced
    with ``_watershed_tile_kernel`` using the current exit labels, and the
    resulting edge rows/columns are written back into ``boundaries``.

    Returns
    -------
    bool
        True when at least one boundary label differs from the stored one
        (two NaNs count as equal).
    """
    fdir_tile = np.asarray(
        flow_dir_da.blocks[iy, ix].compute(), dtype=np.float64)
    pp_tile = np.asarray(
        pour_points_da.blocks[iy, ix].compute(), dtype=np.float64)
    n_rows, n_cols = fdir_tile.shape

    exits = _compute_exit_labels(
        iy, ix, boundaries, flow_bdry,
        chunks_y, chunks_x, n_tile_y, n_tile_x)
    tile_labels = _watershed_tile_kernel(
        fdir_tile, n_rows, n_cols, pp_tile, *exits)

    # Fresh edge strips, in the order they are compared and stored.
    fresh = {
        'top': tile_labels[0, :].copy(),
        'bottom': tile_labels[-1, :].copy(),
        'left': tile_labels[:, 0].copy(),
        'right': tile_labels[:, -1].copy(),
    }

    changed = False
    for side, strip in fresh.items():
        previous = boundaries.get(side, iy, ix).copy()
        with np.errstate(invalid='ignore'):
            # Ignore positions that are NaN both before and after.
            relevant = ~(np.isnan(previous) & np.isnan(strip))
            if np.any(previous[relevant] != strip[relevant]):
                changed = True
        if changed:
            break

    for side, strip in fresh.items():
        boundaries.set(side, iy, ix, strip)

    return changed
def _watershed_dask_iterative(flow_dir_da, pour_points_da):
    """Iterative boundary-propagation for watershed on dask arrays.

    Tiles are swept forward and backward, each sweep refreshing the
    per-tile boundary labels, until a full pass changes nothing.  The
    converged boundaries are then used to assemble the lazy result.

    Parameters
    ----------
    flow_dir_da, pour_points_da : dask.array.Array
        2-D arrays with identical chunking.

    Returns
    -------
    dask.array.Array
        Lazy float64 label array.
    """
    chunks_y = flow_dir_da.chunks[0]
    chunks_x = flow_dir_da.chunks[1]
    n_tile_y = len(chunks_y)
    n_tile_x = len(chunks_x)

    flow_bdry = _preprocess_tiles(flow_dir_da, chunks_y, chunks_x)
    flow_bdry = flow_bdry.snapshot()  # read-only from here; release temp files

    boundaries = BoundaryStore(chunks_y, chunks_x, fill_value=np.nan)

    # A boundary label only ever goes from NaN to its final value, and
    # every pass that reports a change resolves at least one boundary
    # cell.  The number of boundary cells is therefore a safe upper bound
    # on the number of passes; the historical heuristic cap could stop
    # early when a flow path zig-zags across chunk borders, leaving the
    # upstream cells NaN.  The loop still exits as soon as a pass is
    # change-free, so typical inputs run exactly as many passes as before.
    n_boundary_cells = 2 * (n_tile_y * sum(chunks_x)
                            + n_tile_x * sum(chunks_y))
    max_iterations = max(max(n_tile_y, n_tile_x) * 2 + 10,
                         n_boundary_cells + 1)

    forward = [(iy, ix) for iy in range(n_tile_y) for ix in range(n_tile_x)]
    backward = forward[::-1]

    for _iteration in range(max_iterations):
        any_changed = False

        # Forward sweep, then backward sweep
        for sweep in (forward, backward):
            for iy, ix in sweep:
                c = _process_tile_watershed(
                    iy, ix, flow_dir_da, pour_points_da,
                    boundaries, flow_bdry,
                    chunks_y, chunks_x, n_tile_y, n_tile_x)
                if c:
                    any_changed = True

        if not any_changed:
            break

    # Snapshot converged boundaries before assembly (releases temp files)
    boundaries = boundaries.snapshot()

    return _assemble_watershed(flow_dir_da, pour_points_da,
                               boundaries, flow_bdry,
                               chunks_y, chunks_x, n_tile_y, n_tile_x)
def _assemble_watershed(flow_dir_da, pour_points_da,
                        boundaries, flow_bdry,
                        chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Create the lazy result: each block is re-traced with final exit labels.

    The returned dask array maps ``_watershed_tile_kernel`` over matching
    blocks of ``flow_dir_da`` and ``pour_points_da``; the block location
    from ``block_info`` selects the exit labels for that tile.
    """

    def _tile_fn(flow_dir_block, pp_block, block_info=None):
        have_location = block_info is not None and 0 in block_info
        if not have_location:
            # Without a chunk location no exit labels can be looked up.
            return np.full(flow_dir_block.shape, np.nan, dtype=np.float64)

        iy, ix = block_info[0]['chunk-location']
        n_rows, n_cols = flow_dir_block.shape
        exits = _compute_exit_labels(
            iy, ix, boundaries, flow_bdry,
            chunks_y, chunks_x, n_tile_y, n_tile_x)
        fdir = np.asarray(flow_dir_block, dtype=np.float64)
        seeds = np.asarray(pp_block, dtype=np.float64)
        return _watershed_tile_kernel(fdir, n_rows, n_cols, seeds, *exits)

    empty_meta = np.array((), dtype=np.float64)
    return da.map_blocks(
        _tile_fn,
        flow_dir_da, pour_points_da,
        dtype=np.float64,
        meta=empty_meta,
    )
def _watershed_tile_cupy(flow_dir_data, pour_points_data,
                         exit_top, exit_bottom, exit_left, exit_right,
                         exit_tl, exit_tr, exit_bl, exit_br):
    """Seeded GPU watershed for one tile.

    Works like the whole-array GPU driver, except that before the
    propagation loop every still-active boundary cell (state 1) with a
    non-NaN exit label receives that label and becomes frontier (state 2).
    Cells ending in state 0 or 1 are returned as NaN.
    """
    import cupy as cp

    n_rows, n_cols = flow_dir_data.shape
    shape = (n_rows, n_cols)

    fdir = flow_dir_data.astype(cp.float64)
    seeds = pour_points_data.astype(cp.float64)
    labels = cp.zeros(shape, dtype=cp.float64)
    state = cp.zeros(shape, dtype=cp.int32)
    changed = cp.zeros(1, dtype=cp.int32)

    griddim, blockdim = cuda_args(shape)
    _init_watershed_gpu[griddim, blockdim](
        fdir, seeds, labels, state, n_rows, n_cols)

    # Edge strips: (index into the 2-D arrays, host-side exit labels),
    # processed in top, bottom, left, right order.
    whole = slice(None)
    edges = (
        ((0, whole), exit_top),
        ((n_rows - 1, whole), exit_bottom),
        ((whole, 0), exit_left),
        ((whole, n_cols - 1), exit_right),
    )
    for where, host_vals in edges:
        dev_vals = cp.asarray(host_vals)
        take = (state[where] == 1) & ~cp.isnan(dev_vals)
        labels[where] = cp.where(take, dev_vals, labels[where])
        state[where] = cp.where(take, 2, state[where])

    # Corner (diagonal) exit labels
    corners = [(0, 0, exit_tl), (0, n_cols - 1, exit_tr),
               (n_rows - 1, 0, exit_bl), (n_rows - 1, n_cols - 1, exit_br)]
    for r, c, val in corners:
        if val == val and int(state[r, c]) == 1:
            labels[r, c] = val
            state[r, c] = 2

    passes_left = n_rows * n_cols
    while passes_left > 0:
        passes_left -= 1
        changed[0] = 0
        _propagate_labels_gpu[griddim, blockdim](
            fdir, labels, state, changed, n_rows, n_cols)
        if int(changed[0]) == 0:
            break
        _advance_frontier_gpu[griddim, blockdim](state, n_rows, n_cols)

    unlabeled = (state == 1) | (state == 0)
    return cp.where(unlabeled, cp.nan, labels)
def _process_tile_watershed_cupy(iy, ix, flow_dir_da, pour_points_da,
                                 boundaries, flow_bdry,
                                 chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Re-label one tile on the GPU and refresh its boundary strips.

    Returns True when at least one boundary label differs from the one
    previously stored in ``boundaries`` (two NaNs count as equal).
    """
    import cupy as cp

    fdir_tile = cp.asarray(
        flow_dir_da.blocks[iy, ix].compute(), dtype=cp.float64)
    pp_tile = cp.asarray(
        pour_points_da.blocks[iy, ix].compute(), dtype=cp.float64)

    exits = _compute_exit_labels(
        iy, ix, boundaries, flow_bdry,
        chunks_y, chunks_x, n_tile_y, n_tile_x)
    tile_labels = _watershed_tile_cupy(fdir_tile, pp_tile, *exits)

    # Edge strips copied back to the host, in comparison/storage order.
    fresh = {
        'top': tile_labels[0, :].get().copy(),
        'bottom': tile_labels[-1, :].get().copy(),
        'left': tile_labels[:, 0].get().copy(),
        'right': tile_labels[:, -1].get().copy(),
    }

    changed = False
    for side, strip in fresh.items():
        previous = boundaries.get(side, iy, ix).copy()
        with np.errstate(invalid='ignore'):
            relevant = ~(np.isnan(previous) & np.isnan(strip))
            if np.any(previous[relevant] != strip[relevant]):
                changed = True
        if changed:
            break

    for side, strip in fresh.items():
        boundaries.set(side, iy, ix, strip)

    return changed
def _assemble_watershed_cupy(flow_dir_da, pour_points_da,
                             boundaries, flow_bdry,
                             chunks_y, chunks_x, n_tile_y, n_tile_x):
    """Create the lazy dask+cupy result from converged boundaries.

    Every block is re-run through ``_watershed_tile_cupy`` with the exit
    labels belonging to its chunk location.
    """
    import cupy as cp

    def _tile_fn(flow_dir_block, pp_block, block_info=None):
        have_location = block_info is not None and 0 in block_info
        if not have_location:
            return cp.full(flow_dir_block.shape, cp.nan, dtype=cp.float64)

        iy, ix = block_info[0]['chunk-location']
        exits = _compute_exit_labels(
            iy, ix, boundaries, flow_bdry,
            chunks_y, chunks_x, n_tile_y, n_tile_x)
        fdir = cp.asarray(flow_dir_block, dtype=cp.float64)
        seeds = cp.asarray(pp_block, dtype=cp.float64)
        return _watershed_tile_cupy(fdir, seeds, *exits)

    empty_meta = cp.array((), dtype=cp.float64)
    return da.map_blocks(
        _tile_fn,
        flow_dir_da, pour_points_da,
        dtype=np.float64,
        meta=empty_meta,
    )
def _watershed_dask_cupy(flow_dir_da, pour_points_da):
    """Dask+CuPy watershed: native GPU processing per tile.

    Same forward/backward tile sweep as the CPU dask path, with each tile
    solved by the GPU tile kernel.  Sweeps repeat until a full pass leaves
    every boundary strip unchanged.

    Parameters
    ----------
    flow_dir_da, pour_points_da : dask.array.Array (cupy-backed)
        2-D arrays with identical chunking.

    Returns
    -------
    dask.array.Array
        Lazy float64 label array.
    """
    chunks_y = flow_dir_da.chunks[0]
    chunks_x = flow_dir_da.chunks[1]
    n_tile_y = len(chunks_y)
    n_tile_x = len(chunks_x)

    flow_bdry = _preprocess_tiles(flow_dir_da, chunks_y, chunks_x)
    flow_bdry = flow_bdry.snapshot()

    boundaries = BoundaryStore(chunks_y, chunks_x, fill_value=np.nan)

    # Boundary labels only ever go from NaN to their final value and each
    # changing pass resolves at least one boundary cell, so the boundary
    # cell count bounds the number of passes.  The historical heuristic
    # cap could end the loop before convergence for flow paths that cross
    # chunk borders many times, silently leaving cells NaN.
    n_boundary_cells = 2 * (n_tile_y * sum(chunks_x)
                            + n_tile_x * sum(chunks_y))
    max_iterations = max(max(n_tile_y, n_tile_x) * 2 + 10,
                         n_boundary_cells + 1)

    forward = [(iy, ix) for iy in range(n_tile_y) for ix in range(n_tile_x)]
    backward = forward[::-1]

    for _iteration in range(max_iterations):
        any_changed = False

        for sweep in (forward, backward):
            for iy, ix in sweep:
                c = _process_tile_watershed_cupy(
                    iy, ix, flow_dir_da, pour_points_da,
                    boundaries, flow_bdry,
                    chunks_y, chunks_x, n_tile_y, n_tile_x)
                if c:
                    any_changed = True

        if not any_changed:
            break

    boundaries = boundaries.snapshot()

    return _assemble_watershed_cupy(flow_dir_da, pour_points_da,
                                    boundaries, flow_bdry,
                                    chunks_y, chunks_x, n_tile_y, n_tile_x)
# =====================================================================
# Public API
# =====================================================================
# [docs] -- Sphinx "view source" link residue from HTML extraction; not code.
@supports_dataset
def watershed_d8(flow_dir: xr.DataArray,
                 pour_points: xr.DataArray,
                 name: str = 'watershed') -> xr.DataArray:
    """Label each cell with the pour point it drains to.

    Follows each cell downstream through the D8 flow direction grid
    until it reaches a pour point.  The cell is then labeled with that
    pour point's value.  Cells that do not reach any pour point are
    assigned NaN.

    Parameters
    ----------
    flow_dir : xarray.DataArray or xr.Dataset
        2D D8 flow direction grid
        (codes 0/1/2/4/8/16/32/64/128; NaN for nodata).
    pour_points : xarray.DataArray
        2D raster where non-NaN cells are pour points and their
        values become the labels.  Must have the same shape as
        ``flow_dir``.
    name : str, default='watershed'
        Name of output DataArray.

    Returns
    -------
    xarray.DataArray or xr.Dataset
        2D float64 array where each cell = label of its pour point.
        NaN for nodata or cells not reaching any pour point.

    Raises
    ------
    ValueError
        If ``pour_points`` does not have the same shape as ``flow_dir``.
    MemoryError
        If the in-memory (numpy or cupy) path would need more than half
        of the available host / device memory.
    TypeError
        If ``flow_dir`` is backed by an unsupported array type.
    """
    _validate_raster(flow_dir, func_name='watershed', name='flow_dir')
    _validate_raster(pour_points, func_name='watershed', name='pour_points')

    # Enforce the documented contract up front instead of failing (or
    # silently mislabelling) somewhere inside a backend kernel.
    if tuple(pour_points.shape) != tuple(flow_dir.shape):
        raise ValueError(
            f"watershed: pour_points shape {tuple(pour_points.shape)} does "
            f"not match flow_dir shape {tuple(flow_dir.shape)}"
        )

    data = flow_dir.data
    pp_data = pour_points.data

    if isinstance(data, np.ndarray):
        _check_memory(*data.shape)
        fd = data.astype(np.float64)
        pp = np.asarray(pp_data, dtype=np.float64)
        h, w = fd.shape

        # Vectorised initialisation:
        #   NaN flow_dir             -> state 0 (nodata), label NaN
        #   valid cell + pour point  -> state 3 (resolved), label = pp
        #   any other valid cell     -> state 1 (unresolved), label NaN
        valid = ~np.isnan(fd)
        is_pour = valid & ~np.isnan(pp)

        labels = np.full((h, w), np.nan, dtype=np.float64)
        np.copyto(labels, pp, where=is_pour)

        state = valid.astype(np.int8)  # True -> 1, False -> 0
        state[is_pour] = 3

        # Drop the masks before the kernel allocates its path buffers.
        del valid, is_pour

        out = _watershed_cpu(fd, labels, state, h, w)

    elif has_cuda_and_cupy() and is_cupy_array(data):
        _check_gpu_memory(*data.shape)
        out = _watershed_cupy(data, pp_data)

    elif has_cuda_and_cupy() and is_dask_cupy(flow_dir):
        out = _watershed_dask_cupy(data, pp_data)

    elif da is not None and isinstance(data, da.Array):
        out = _watershed_dask_iterative(data, pp_data)

    else:
        raise TypeError(f"Unsupported array type: {type(data)}")

    return xr.DataArray(out,
                        name=name,
                        coords=flow_dir.coords,
                        dims=flow_dir.dims,
                        attrs=flow_dir.attrs)
# [docs] -- Sphinx "view source" link residue from HTML extraction; not code.
def basins_d8(flow_dir, name='basins'):
    """Deprecated alias kept for backward compatibility.

    Emits a ``DeprecationWarning`` attributed to the caller and then
    forwards ``flow_dir`` and ``name`` to ``basin_d8``.
    """
    from warnings import warn

    warn(
        "basins_d8 is deprecated; use basin (basin_d8) instead.",
        DeprecationWarning,
        stacklevel=2,
    )

    # Imported only after warning, at call time.
    from .basin_d8 import basin_d8 as _basin_d8

    return _basin_d8(flow_dir, name=name)