# Source code for xrspatial.sieve

"""Sieve filter for removing small raster clumps.

Given a categorical raster and a pixel-count threshold, replaces
connected regions smaller than the threshold with the value of
their largest spatial neighbor that is already at or above the
threshold.  Matches the single-pass semantics of GDAL's
``GDALSieveFilter`` / ``rasterio.features.sieve``.

Pairs with classification functions (``natural_breaks``,
``reclassify``, etc.) and ``polygonize`` for cleaning results
before vectorization.

Supports all four backends: numpy, cupy, dask+numpy, dask+cupy.
"""

from __future__ import annotations

from collections import defaultdict
from typing import Sequence

import numpy as np
import xarray as xr
from xarray import DataArray

# Optional GPU dependency.  When CuPy cannot be imported, the name ``cupy``
# is bound to a minimal placeholder class so the rest of the module still
# loads on CPU-only installs.
try:
    import cupy
except ImportError:

    class cupy:
        # Placeholder attribute standing in for ``cupy.ndarray``.
        # NOTE(review): not referenced anywhere in the visible code.
        ndarray = False


# Optional dask dependency.  ``da`` is None when dask is not installed;
# the public entry point checks ``da is not None`` before using it.
try:
    import dask.array as da
except ImportError:
    da = None

from xrspatial.utils import (
    _validate_raster,
    has_cuda_and_cupy,
    is_cupy_array,
    is_dask_cupy,
    ngjit,
)



# ---------------------------------------------------------------------------
# Numba union-find labeling
# ---------------------------------------------------------------------------


@ngjit
def _uf_find(parent, x):
    """Return the root of *x*, shortening the path as it is walked.

    Every visited node is re-pointed at its grandparent (path halving),
    so *parent* is modified in place.
    """
    node = x
    while True:
        up = parent[node]
        if up == node:
            # A node that is its own parent is the root of its set.
            return node
        grand = parent[up]
        parent[node] = grand
        node = grand


@ngjit
def _uf_union(parent, rank, a, b):
    """Merge the sets containing *a* and *b*, attaching by rank.

    The root with the lower rank is hung under the other one.  On a tie
    the root of *b* is hung under the root of *a* and the rank of *a*'s
    root is bumped by one.  Both arrays are modified in place.
    """
    root_a = _uf_find(parent, a)
    root_b = _uf_find(parent, b)
    if root_a != root_b:
        if rank[root_a] < rank[root_b]:
            parent[root_a] = root_b
        else:
            if rank[root_a] == rank[root_b]:
                rank[root_a] += 1
            parent[root_b] = root_a


@ngjit
def _label_connected(data, valid, neighborhood):
    """Label connected same-value regions of a 2D raster via union-find.

    One raster scan unions each valid pixel with its already-visited
    neighbours that hold the same value; two further linear scans then
    count the distinct roots and assign contiguous region ids.  This
    replaces the previous approach of calling ``scipy.ndimage.label``
    once per unique raster value.

    Uses int32 indices internally, so the raster must have fewer than
    ~2.1 billion pixels (roughly 46 000 x 46 000).

    Parameters
    ----------
    data : 2D array
        Raster values; pixels are compared with ``==``.
    valid : 2D array
        Truthy where the pixel takes part in labeling; other pixels
        keep region id 0.
    neighborhood : int
        ``8`` additionally links the two upper diagonal neighbours;
        any other value links only the left and upper neighbours.

    Returns
    -------
    region_map : ndarray of int32 (2D)
        Each pixel mapped to its region id (0 = nodata).
    region_val : ndarray of float64 (1D)
        Original raster value for each region id (index 0 is NaN).
    n_regions : int
        Total number of regions + 1 (length of *region_val*).
    """
    rows = data.shape[0]
    cols = data.shape[1]
    n = rows * cols
    # One union-find node per pixel, addressed by flat row-major index.
    parent = np.arange(n, dtype=np.int32)
    rank = np.zeros(n, dtype=np.int32)

    for r in range(rows):
        for c in range(cols):
            if not valid[r, c]:
                continue
            idx = r * cols + c
            val = data[r, c]

            # Check left (already visited)
            if c > 0 and valid[r, c - 1] and data[r, c - 1] == val:
                _uf_union(parent, rank, idx, idx - 1)
            # Check up (already visited)
            if r > 0 and valid[r - 1, c] and data[r - 1, c] == val:
                _uf_union(parent, rank, idx, (r - 1) * cols + c)

            if neighborhood == 8:
                # Upper-left diagonal (already visited)
                if (
                    r > 0
                    and c > 0
                    and valid[r - 1, c - 1]
                    and data[r - 1, c - 1] == val
                ):
                    _uf_union(parent, rank, idx, (r - 1) * cols + (c - 1))
                # Upper-right diagonal (already visited)
                if (
                    r > 0
                    and c + 1 < cols
                    and valid[r - 1, c + 1]
                    and data[r - 1, c + 1] == val
                ):
                    _uf_union(parent, rank, idx, (r - 1) * cols + (c + 1))

    # --- Count unique regions first so region_val_buf is right-sized ---
    # Reuse rank array (no longer needed after union-find) as root_to_id.
    # This eliminates a separate n-element int32 allocation.
    root_to_id = rank  # alias; rank is dead after union-find
    for i in range(n):
        root_to_id[i] = 0  # clear

    n_regions = 0
    for i in range(n):
        r = i // cols
        c = i % cols
        if not valid[r, c]:
            continue
        root = _uf_find(parent, i)
        if root_to_id[root] == 0:
            root_to_id[root] = 1  # mark as seen
            n_regions += 1

    # Allocate region_val_buf at actual region count, not pixel count.
    # For a 46K x 46K raster with 100K regions this saves ~16 GB.
    # Slot 0 stays NaN: id 0 is reserved for pixels that are not valid.
    region_val_buf = np.full(n_regions + 1, np.nan, dtype=np.float64)

    # Assign contiguous region IDs
    region_map_flat = np.zeros(n, dtype=np.int32)
    for i in range(n):
        root_to_id[i] = 0  # clear for ID assignment
    next_id = 1

    for i in range(n):
        r = i // cols
        c = i % cols
        if not valid[r, c]:
            continue
        root = _uf_find(parent, i)
        if root_to_id[root] == 0:
            # First pixel seen for this root: give it the next id and
            # record the raster value that the region carries.
            root_to_id[root] = next_id
            region_val_buf[next_id] = data[r, c]
            next_id += 1
        region_map_flat[i] = root_to_id[root]

    region_map = region_map_flat.reshape(rows, cols)
    return region_map, region_val_buf[:next_id], next_id


# ---------------------------------------------------------------------------
# Adjacency helpers
# ---------------------------------------------------------------------------


def _build_adjacency(region_map, neighborhood):
    """Build a region adjacency dict from a labeled map.

    Encodes each (lo, hi) region pair as a single int64 so
    deduplication uses fast 1-D ``np.unique`` instead of the slower
    ``np.unique(axis=0)`` on 2-D pair arrays.

    Returns ``{region_id: set_of_neighbor_ids}``.
    """
    max_id = np.int64(region_map.max()) + 1
    encoded_parts: list[np.ndarray] = []

    def _collect(a, b):
        mask = (a > 0) & (b > 0) & (a != b)
        if not mask.any():
            return
        am = a[mask].ravel().astype(np.int64)
        bm = b[mask].ravel().astype(np.int64)
        lo = np.minimum(am, bm)
        hi = np.maximum(am, bm)
        encoded_parts.append(lo * max_id + hi)

    _collect(region_map[:-1, :], region_map[1:, :])
    _collect(region_map[:, :-1], region_map[:, 1:])
    if neighborhood == 8:
        _collect(region_map[:-1, :-1], region_map[1:, 1:])
        _collect(region_map[:-1, 1:], region_map[1:, :-1])

    adjacency: dict[int, set[int]] = defaultdict(set)
    if not encoded_parts:
        return adjacency

    encoded = np.unique(np.concatenate(encoded_parts))
    lo_arr = encoded // max_id
    hi_arr = encoded % max_id

    for a, b in zip(lo_arr.tolist(), hi_arr.tolist()):
        adjacency[a].add(b)
        adjacency[b].add(a)

    return adjacency


# ---------------------------------------------------------------------------
# Memory guards
# ---------------------------------------------------------------------------

# Estimated peak working set per pixel for the union-find pass:
#   result copy             8 bytes (float64)
#   parent                  4 bytes (int32)
#   rank / root_to_id       4 bytes (int32, reused)
#   region_map_flat         4 bytes (int32)
#   slack for region_val,
#   region_size, valid mask 8 bytes
# Total ~28 bytes/pixel.  The dask backends apply the same 28 bytes/pixel
# figure in their own guards.
_BYTES_PER_PIXEL = 28


def _available_memory_bytes():
    """Best-effort estimate of available memory in bytes."""
    try:
        with open("/proc/meminfo", "r") as f:
            for line in f:
                if line.startswith("MemAvailable:"):
                    return int(line.split()[1]) * 1024
    except (OSError, ValueError, IndexError):
        pass
    try:
        import psutil

        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass
    return 2 * 1024**3


def _available_gpu_memory_bytes():
    """Best-effort estimate of free GPU memory in bytes.

    Returns 0 when CuPy / CUDA is unavailable or the query fails -- callers
    treat that as "no GPU info, skip the guard".
    """
    try:
        import cupy as _cp

        free, _total = _cp.cuda.runtime.memGetInfo()
        return int(free)
    except Exception:
        return 0


def _check_memory(rows, cols):
    """Guard the union-find pass against exhausting host memory.

    Raises
    ------
    MemoryError
        When ``rows * cols * _BYTES_PER_PIXEL`` exceeds half of the
        memory reported by ``_available_memory_bytes()``.
    """
    needed = int(rows) * int(cols) * _BYTES_PER_PIXEL
    free = _available_memory_bytes()
    if needed <= 0.5 * free:
        return
    raise MemoryError(
        f"sieve() on a {rows}x{cols} raster needs "
        f"~{needed / 1e9:.1f} GB of working memory but only "
        f"~{free / 1e9:.1f} GB is available.  "
        f"Connected-component labeling is a global operation that "
        f"cannot be chunked.  Consider downsampling or tiling the "
        f"input manually."
    )


def _check_gpu_memory(rows, cols):
    """Raise MemoryError when the CuPy round-trip would not fit.

    The CuPy backend transfers to host and runs the CPU sieve, so the
    host budget still applies; free GPU RAM is checked as well so a
    user with little VRAM gets a clear error before ``data.get()`` runs.
    Skips the GPU check silently when the GPU memory query fails.

    Raises
    ------
    MemoryError
        From ``_check_memory`` when the host budget is exceeded, or
        here when 16 bytes/pixel exceeds half of the free GPU memory.
    """
    _check_memory(rows, cols)
    available = _available_gpu_memory_bytes()
    if available <= 0:
        return
    # Round-trip needs the float64 input on device plus a float64 result.
    required = int(rows) * int(cols) * 16
    if required > 0.5 * available:
        # The dask+cupy backend computes the whole array and then runs this
        # same guard, so it is not an out-of-core escape hatch; point the
        # user at the remedies that actually reduce the footprint.
        raise MemoryError(
            f"sieve() on a {rows}x{cols} cupy raster needs "
            f"~{required / 1e9:.1f} GB of GPU memory for the round-trip "
            f"but only ~{available / 1e9:.1f} GB is free on the active "
            f"device.  Connected-component labeling is a global "
            f"operation that cannot be chunked.  Consider downsampling "
            f"or tiling the input manually."
        )


# ---------------------------------------------------------------------------
# numpy backend
# ---------------------------------------------------------------------------


def _sieve_numpy(data, threshold, neighborhood, skip_values):
    """Single-pass sieve matching GDAL's ``GDALSieveFilter`` semantics.

    A small region is only merged into a neighbor whose size is
    **>= threshold**.  If no such neighbor exists the region stays.
    Candidates are visited in ascending order of their initial size.
    Only merge targets grow, so the visiting order can influence which
    neighbor is the largest for a later candidate, but it never lifts
    a below-threshold region above the threshold.

    Parameters
    ----------
    data : 2D numpy.ndarray
        Input raster.  For floating dtypes NaN pixels are excluded
        from labeling and left untouched.
    threshold : int
        Regions with fewer pixels than this are merge candidates.
    neighborhood : int
        Connectivity forwarded to the labeling and adjacency helpers.
    skip_values : sequence or None
        Raster values whose regions are never replaced.

    Returns
    -------
    numpy.ndarray
        A float64 copy of *data* with the merged regions overwritten.

    Raises
    ------
    MemoryError
        Propagated from ``_check_memory``.
    """
    _check_memory(data.shape[0], data.shape[1])
    result = data.astype(np.float64, copy=True)
    is_float = np.issubdtype(data.dtype, np.floating)
    valid = ~np.isnan(result) if is_float else np.ones(result.shape, dtype=bool)
    skip_set = set(skip_values) if skip_values is not None else set()

    region_map, region_val, uid = _label_connected(
        result, valid, neighborhood
    )
    region_size = np.bincount(
        region_map.ravel(), minlength=uid
    ).astype(np.int64)

    small_ids = [
        rid
        for rid in range(1, uid)
        if region_size[rid] < threshold
        and region_val[rid] not in skip_set
    ]
    if not small_ids:
        return result

    adjacency = _build_adjacency(region_map, neighborhood)

    # Visit the smallest candidates first.
    small_ids.sort(key=lambda r: region_size[r])

    # merge_target[rid] is the id of the region that absorbed ``rid``
    # (0 = not merged).  Pixels are rewritten once after the loop instead
    # of scanning the whole raster for every merged region.  This is safe
    # because a target is always >= threshold and therefore is never a
    # candidate itself, so there are no merge chains to resolve.
    merge_target = np.zeros(uid, dtype=np.int64)

    for rid in small_ids:
        # Defensive: skip anything already absorbed or no longer small.
        if region_size[rid] == 0 or region_size[rid] >= threshold:
            continue

        neighbors = adjacency.get(rid)
        if not neighbors:
            continue  # surrounded by nodata only

        # Only merge into a neighbor that is already >= threshold.
        valid_neighbors = [
            n for n in neighbors if region_size[n] >= threshold
        ]
        if not valid_neighbors:
            continue

        largest_nid = max(valid_neighbors, key=lambda n: region_size[n])
        merge_target[rid] = largest_nid

        # Update size tracking in place.
        region_size[largest_nid] += region_size[rid]
        region_size[rid] = 0

        # Hand the absorbed region's contacts over to the target.
        for n in neighbors:
            if n != largest_nid:
                adjacency[n].discard(rid)
                adjacency[n].add(largest_nid)
                adjacency.setdefault(largest_nid, set()).add(n)
        if largest_nid in adjacency:
            adjacency[largest_nid].discard(rid)
        del adjacency[rid]

    merged = merge_target > 0
    if merged.any():
        # Per-region output value: the target's original raster value.
        new_val = region_val.copy()
        new_val[merged] = region_val[merge_target[merged]]
        # Rewrite only the pixels that belong to a merged region.
        touched = merged[region_map]
        result[touched] = new_val[region_map[touched]]

    return result


# ---------------------------------------------------------------------------
# cupy backend  (CPU fallback – merge logic is serial)
# ---------------------------------------------------------------------------


def _sieve_cupy(data, threshold, neighborhood, skip_values):
    """CuPy backend: copy to host, run the numpy sieve, copy back.

    The host and GPU memory guards run before the device-to-host
    transfer; a ``MemoryError`` from them propagates to the caller.
    """
    import cupy as cp

    n_rows, n_cols = data.shape[0], data.shape[1]
    _check_gpu_memory(n_rows, n_cols)

    host_in = data.get()
    host_out = _sieve_numpy(host_in, threshold, neighborhood, skip_values)
    return cp.asarray(host_out)


# ---------------------------------------------------------------------------
# dask backends
# ---------------------------------------------------------------------------


def _sieve_dask(data, threshold, neighborhood, skip_values):
    """Dask+numpy backend: compute to numpy, sieve, wrap back.

    The whole array is materialised with ``data.compute()``; the result
    is re-wrapped as a dask array using the input's chunk layout.

    Raises
    ------
    MemoryError
        When the estimated working set exceeds half of the available
        host memory.
    """
    avail = _available_memory_bytes()
    n_pixels = np.prod(data.shape)
    # Same per-pixel budget as the eager numpy guard, taken from the
    # shared constant so the two cannot drift apart.
    estimated_bytes = n_pixels * _BYTES_PER_PIXEL
    if estimated_bytes > 0.5 * avail:
        raise MemoryError(
            f"sieve() needs ~{estimated_bytes / 1e9:.1f} GB for the full "
            f"array plus CCL bookkeeping, but only ~{avail / 1e9:.1f} GB "
            f"is available.  Connected-component labeling is a global "
            f"operation that cannot be chunked.  Consider downsampling "
            f"or tiling the input manually."
        )

    np_data = data.compute()
    result = _sieve_numpy(
        np_data, threshold, neighborhood, skip_values
    )
    return da.from_array(result, chunks=data.chunks)


def _sieve_dask_cupy(data, threshold, neighborhood, skip_values):
    """Dask+CuPy backend: compute to cupy, sieve via CPU fallback, wrap back.

    The whole array is materialised with ``data.compute()`` and handed
    to the CuPy backend; the result is re-wrapped as a dask array using
    the input's chunk layout.

    Raises
    ------
    MemoryError
        When the estimated working set exceeds half of the free GPU
        memory.  The guard is skipped if the free-memory query raises
        ``ImportError`` or ``AttributeError``.
    """
    n_pixels = np.prod(data.shape)
    # Shared per-pixel budget, kept in sync with the other backends.
    # NOTE(review): this figure includes bookkeeping arrays that the CuPy
    # backend allocates on the host, so it is conservative for VRAM.
    estimated_bytes = n_pixels * _BYTES_PER_PIXEL

    # Keep the try body to the query itself so the MemoryError below can
    # never be mistaken for a failed lookup.
    try:
        import cupy as cp

        free_gpu, _total = cp.cuda.Device().mem_info
    except (ImportError, AttributeError):
        free_gpu = None

    if free_gpu is not None and estimated_bytes > 0.5 * free_gpu:
        raise MemoryError(
            f"sieve() needs ~{estimated_bytes / 1e9:.1f} GB for the "
            f"full array plus CCL bookkeeping, but only "
            f"~{free_gpu / 1e9:.1f} GB free GPU memory.  Connected-"
            f"component labeling is a global operation that cannot be "
            f"chunked.  Consider downsampling or tiling the input."
        )

    cp_data = data.compute()
    result = _sieve_cupy(
        cp_data, threshold, neighborhood, skip_values
    )
    return da.from_array(result, chunks=data.chunks)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def sieve(
    raster: xr.DataArray,
    threshold: int = 10,
    neighborhood: int = 4,
    skip_values: Sequence[float] | None = None,
    name: str = "sieve",
) -> xr.DataArray:
    """Remove small connected regions from a classified raster.

    Identifies connected components of same-value pixels and replaces
    regions smaller than *threshold* pixels with the value of their
    largest spatial neighbor that is already at or above *threshold*.
    Regions whose only neighbors are also below *threshold* are left
    unchanged, matching GDAL's single-pass semantics.  NaN pixels are
    always preserved.

    Parameters
    ----------
    raster : xr.DataArray
        2D classified or categorical raster.
    threshold : int, default=10
        Minimum region size in pixels.  Regions with fewer pixels are
        replaced by their largest neighbor's value.
    neighborhood : int, default=4
        Pixel connectivity: 4 (rook) or 8 (queen).
    skip_values : sequence of float, optional
        Category values whose regions are never replaced, regardless
        of size.  These regions can still serve as merge targets for
        neighboring small regions.
    name : str, default='sieve'
        Output DataArray name.

    Returns
    -------
    xr.DataArray
        Sieved raster with the same shape, dims, coords, and attrs.

    Raises
    ------
    ValueError
        If *neighborhood* is not 4 or 8, or *threshold* is not a
        positive integer.
    TypeError
        If the array backing *raster* is not a supported type.
    MemoryError
        If a backend's memory guard estimates that the working set
        will not fit.

    Examples
    --------
    .. sourcecode:: python

        >>> import numpy as np
        >>> import xarray as xr
        >>> from xrspatial.sieve import sieve

        >>> # Classified raster with salt-and-pepper noise
        >>> arr = np.array([[1, 1, 1, 2, 2],
        ...                 [1, 3, 1, 2, 2],
        ...                 [1, 1, 1, 2, 2],
        ...                 [2, 2, 2, 2, 2],
        ...                 [2, 2, 2, 2, 2]], dtype=np.float64)
        >>> raster = xr.DataArray(arr, dims=['y', 'x'])

        >>> # Remove regions smaller than 2 pixels
        >>> result = sieve(raster, threshold=2)
        >>> print(result.values)
        [[1. 1. 1. 2. 2.]
         [1. 1. 1. 2. 2.]
         [1. 1. 1. 2. 2.]
         [2. 2. 2. 2. 2.]
         [2. 2. 2. 2. 2.]]

    Notes
    -----
    Uses single-pass semantics matching GDAL's ``GDALSieveFilter``.
    A small region is only merged into a neighbor whose current size
    is >= *threshold*.  If no such neighbor exists the region is left
    unchanged.

    This is a global operation: for dask-backed arrays the entire
    raster is computed into memory before sieving.  Connected-component
    labeling cannot be performed on individual chunks because regions
    may span chunk boundaries.

    The CuPy backends use a CPU fallback for the merge step, which is
    inherently serial.

    See Also
    --------
    xrspatial.zonal.regions : Connected-component labeling.
    xrspatial.classify.natural_breaks : Classification that may produce
        noisy output suitable for sieving.
    """
    _validate_raster(raster, func_name="sieve", name="raster", ndim=2)

    if neighborhood not in (4, 8):
        raise ValueError("`neighborhood` must be 4 or 8")
    if not isinstance(threshold, (int, np.integer)) or threshold < 1:
        raise ValueError("`threshold` must be a positive integer")

    data = raster.data

    # Dispatch on the type of the backing array.
    if isinstance(data, np.ndarray):
        out = _sieve_numpy(data, threshold, neighborhood, skip_values)
    elif has_cuda_and_cupy() and is_cupy_array(data):
        out = _sieve_cupy(data, threshold, neighborhood, skip_values)
    elif da is not None and isinstance(data, da.Array):
        if is_dask_cupy(raster):
            out = _sieve_dask_cupy(
                data, threshold, neighborhood, skip_values
            )
        else:
            out = _sieve_dask(data, threshold, neighborhood, skip_values)
    else:
        raise TypeError(
            f"Unsupported array type {type(data).__name__} for sieve()"
        )

    return DataArray(
        out,
        name=name,
        dims=raster.dims,
        coords=raster.coords,
        attrs=raster.attrs,
    )