# Source code for xrspatial.preview

"""Memory-safe raster preview via downsampling."""

import numpy as np
import xarray as xr

from xrspatial.dataset_support import supports_dataset
from xrspatial.utils import (
    _validate_raster,
    has_cuda_and_cupy,
    is_cupy_array,
)

# Methods that ``preview`` implements as a per-block reduction.
_COARSEN_METHODS = ('mean', 'median', 'max', 'min')
# Every value accepted for ``preview(method=...)``: the block reductions
# followed by the two resampling modes.
_METHODS = _COARSEN_METHODS + ('nearest', 'bilinear')


def _nan_full(oh, ow, block):
    """Build an ``(oh, ow)`` array of NaN that mirrors *block*.

    The result uses *block*'s dtype.  It is a cupy array when *block* is a
    ``cupy.ndarray`` and a numpy array otherwise (including when cupy is
    not installed).
    """
    shape = (oh, ow)
    xp = np
    try:
        import cupy
        if isinstance(block, cupy.ndarray):
            xp = cupy
    except ImportError:
        # No cupy available: *block* cannot be a cupy array.
        pass
    return xp.full(shape, np.nan, dtype=block.dtype)


def _is_all_nan(block):
    """Return True when every element of a float array is NaN.

    Accepts numpy and cupy arrays.  Empty arrays and arrays whose dtype
    kind is not ``'f'`` always yield False.  The first element is checked
    before anything else so the common case (data that does not start
    with NaN) exits without scanning the array; otherwise a single
    ``nanmax`` reduction decides the answer.
    """
    if block.size == 0 or block.dtype.kind != 'f':
        return False

    head = block.flat[0]
    head_is_nan = head != head      # NaN is the only value unequal to itself
    if not head_is_nan:
        return False

    # From here on: nanmax is NaN exactly when there is no non-NaN element.
    on_gpu = False
    try:
        import cupy
        on_gpu = isinstance(block, cupy.ndarray)
    except ImportError:
        pass

    if on_gpu:
        with np.errstate(invalid='ignore'):
            return bool(cupy.isnan(cupy.nanmax(block)))

    import warnings
    # numpy emits a RuntimeWarning for an all-NaN input; silence it locally.
    with warnings.catch_warnings(), np.errstate(invalid='ignore'):
        warnings.simplefilter('ignore', RuntimeWarning)
        peak = np.nanmax(block)
        return bool(np.isnan(peak))


# ---------------------------------------------------------------------------
# Block reduction (numpy / cupy)
# ---------------------------------------------------------------------------

def _block_reduce(data, factor_y, factor_x, method):
    """Trim-reshape-reduce a 2D array.  Works with numpy and cupy.

    Parameters
    ----------
    data : numpy.ndarray or cupy.ndarray
        2D input array.
    factor_y, factor_x : int
        Block size along the first and second axis.  Rows/columns that do
        not fill a complete block are dropped.
    method : str
        ``'median'`` is handled explicitly; any other value is looked up
        as an array method (e.g. ``'mean'``) and called with
        ``axis=(1, 3)`` on the blocked view.

    Returns
    -------
    numpy.ndarray or cupy.ndarray
        Array of shape ``(data.shape[0] // factor_y,
        data.shape[1] // factor_x)`` on the same backend as *data*.
        The median result is cast back to ``data.dtype`` on both
        backends so numpy and cupy agree.
    """
    oh = data.shape[0] // factor_y
    ow = data.shape[1] // factor_x

    def _array_module():
        # Resolved lazily: only the empty and median paths need to know
        # which backend to call into.
        if has_cuda_and_cupy() and is_cupy_array(data):
            import cupy
            return cupy
        return np

    if oh == 0 or ow == 0:
        # Input is smaller than one block along at least one axis.
        return _array_module().empty((oh, ow), dtype=data.dtype)
    if _is_all_nan(data):
        # Any reduction of all-NaN data is NaN; skip the work.
        return _nan_full(oh, ow, data)

    trimmed = data[:oh * factor_y, :ow * factor_x]
    blocks = trimmed.reshape(oh, factor_y, ow, factor_x)
    if method == 'median':
        # Arrays have no ``.median`` method taking an axis tuple, so bring
        # each block's pixels onto one trailing axis first.
        flat = blocks.transpose(0, 2, 1, 3).reshape(oh, ow, -1)
        return _array_module().median(flat, axis=2).astype(data.dtype)
    return getattr(blocks, method)(axis=(1, 3))


def _reduce_local(agg, factor_y, factor_x, method, y_dim, x_dim):
    """Block-reduce a DataArray whose data is already in memory.

    The underlying numpy / cupy array is reduced with ``_block_reduce``
    and wrapped in a new DataArray.  Coordinates present on *agg* for
    *y_dim* / *x_dim* are resampled to the reduced length; ``attrs`` are
    carried over.
    """
    reduced = _block_reduce(agg.data, factor_y, factor_x, method)
    new_coords = {
        dim: _interpolate_coords(agg.coords[dim], size)
        for dim, size in zip((y_dim, x_dim), reduced.shape)
        if dim in agg.coords
    }
    return xr.DataArray(
        reduced,
        dims=[y_dim, x_dim],
        coords=new_coords,
        attrs=agg.attrs,
    )


# ---------------------------------------------------------------------------
# Dask block reduction via map_blocks
# ---------------------------------------------------------------------------

def _snap_factor(chunk_size, factor):
    """Return the divisor of *chunk_size* closest to *factor*.

    When the reduction factor evenly divides every chunk, no rechunking
    is needed and the dask graph stays minimal.  The output dimensions
    may overshoot the target; a cheap in-memory second pass corrects
    that afterwards.

    Parameters
    ----------
    chunk_size : int
        Chunk length along one axis.
    factor : int
        Desired reduction factor (>= 1).

    Returns
    -------
    int
        *factor* itself when it already divides *chunk_size*; otherwise
        the divisor of *chunk_size* (from 1 up to and including
        *chunk_size*) nearest to *factor*.  On a tie the divisor found
        first wins, which keeps the smaller divisor ahead of
        *chunk_size* itself.
    """
    if chunk_size % factor == 0:
        return factor
    best = 1          # 1 always divides; guarantees a result
    best_dist = abs(1 - factor)
    for d in range(2, int(chunk_size ** 0.5) + 1):
        if chunk_size % d == 0:
            for candidate in (d, chunk_size // d):
                dist = abs(candidate - factor)
                if dist < best_dist:
                    best_dist = dist
                    best = candidate
    # The loop starts at 2, so the partner of divisor 1 -- chunk_size
    # itself -- was never examined.  Without it a prime chunk size could
    # only ever snap to 1, i.e. no reduction at all.
    if abs(chunk_size - factor) < best_dist:
        best = chunk_size
    return best


def _reduce_dask(agg, factor_y, factor_x, method, y_dim, x_dim):
    """Block reduction for dask-backed DataArrays.

    Uses ``dask.array.map_blocks`` so each chunk is independently
    trim-reshape-reduced in a single task.  This produces one graph
    layer on top of the input instead of the five layers that
    ``xarray.coarsen`` generates (reshape, mean_chunk, mean_agg, …).

    Parameters
    ----------
    agg : xr.DataArray
        2D dask-backed raster.
    factor_y, factor_x : int
        Block size along *y_dim* / *x_dim*.  Each chunk is trimmed to a
        whole number of blocks on its own, so an output chunk has
        ``chunk // factor`` rows/columns.
    method : str
        ``'median'`` or the name of an array reduction method
        (``'mean'``, ``'max'``, ``'min'``).
    y_dim, x_dim : str
        Names of the two dimensions of *agg*.

    Returns
    -------
    xr.DataArray
        Lazy reduced raster with resampled coordinates and *agg*'s attrs.
    """
    import dask.array as da

    data = agg.data

    out_chunks_y = tuple(c // factor_y for c in data.chunks[0])
    out_chunks_x = tuple(c // factor_x for c in data.chunks[1])

    # ``mean`` promotes integer / bool input to float64, so declaring
    # ``agg.dtype`` would make the lazy array's dtype disagree with the
    # blocks it computes.  Probe numpy for the real result dtype.  The
    # other methods keep the input dtype (median is cast back below).
    if method == 'mean':
        out_dtype = np.zeros((1,), dtype=agg.dtype).mean().dtype
    else:
        out_dtype = np.dtype(agg.dtype)

    # Captured by the closure; serialised into the task graph.
    _fy, _fx, _m, _dt = factor_y, factor_x, method, out_dtype

    def _reduce_block(block):
        oh = block.shape[0] // _fy
        ow = block.shape[1] // _fx
        if oh == 0 or ow == 0:
            # Chunk smaller than one block: contributes nothing.
            return np.empty((oh, ow), dtype=_dt)
        if _is_all_nan(block):
            return _nan_full(oh, ow, block)
        trimmed = block[:oh * _fy, :ow * _fx]
        blocks = trimmed.reshape(oh, _fy, ow, _fx)
        if _m == 'median':
            flat = blocks.transpose(0, 2, 1, 3).reshape(oh, ow, -1)
            return np.median(flat, axis=2).astype(block.dtype)
        return getattr(blocks, _m)(axis=(1, 3))

    result_data = da.map_blocks(
        _reduce_block, data,
        dtype=out_dtype,
        chunks=(out_chunks_y, out_chunks_x),
    )

    out_h = sum(out_chunks_y)
    out_w = sum(out_chunks_x)
    coords = {}
    if y_dim in agg.coords:
        coords[y_dim] = _interpolate_coords(agg.coords[y_dim], out_h)
    if x_dim in agg.coords:
        coords[x_dim] = _interpolate_coords(agg.coords[x_dim], out_w)

    return xr.DataArray(
        result_data, dims=[y_dim, x_dim], coords=coords, attrs=agg.attrs,
    )


# ---------------------------------------------------------------------------
# Bilinear helpers
# ---------------------------------------------------------------------------

def _bilinear_numpy(data, out_h, out_w):
    """Resize a 2D numpy array with ``scipy.ndimage.zoom`` (``order=1``).

    The per-axis zoom factors are the ratios of the requested size
    ``(out_h, out_w)`` to the current shape of *data*.
    """
    from scipy.ndimage import zoom

    in_h, in_w = data.shape[0], data.shape[1]
    scale = (out_h / in_h, out_w / in_w)
    return zoom(data, scale, order=1)


def _bilinear_cupy(data, out_h, out_w):
    """Resize a 2D cupy array with ``cupyx.scipy.ndimage.zoom`` (``order=1``).

    The per-axis zoom factors are the ratios of the requested size
    ``(out_h, out_w)`` to the current shape of *data*.
    """
    from cupyx.scipy.ndimage import zoom

    in_h, in_w = data.shape[0], data.shape[1]
    scale = (out_h / in_h, out_w / in_w)
    return zoom(data, scale, order=1)


def _bilinear_dask(agg, out_h, out_w, y_dim, x_dim):
    """Bilinear resize of a dask-backed DataArray, one chunk at a time.

    Every input chunk is zoomed on its own through ``map_blocks``, so no
    task ever holds more than a single input chunk.  Output chunk sizes
    are derived from the scaled input chunk boundaries so that they add
    up to exactly ``out_h`` / ``out_w``.
    """
    import dask.array as da

    def _scaled_chunks(in_chunks, n_in, n_out):
        # Scale the cumulative chunk boundaries, round them to integers
        # and difference them back into per-chunk lengths.
        bounds = np.cumsum([0] + list(in_chunks))
        scaled = np.round(bounds * (n_out / n_in)).astype(int)
        return tuple(int(step) for step in np.diff(scaled))

    src_h, src_w = agg.shape
    target_rows = _scaled_chunks(agg.data.chunks[0], src_h, out_h)
    target_cols = _scaled_chunks(agg.data.chunks[1], src_w, out_w)

    def _zoom_block(block, block_info=None):
        from scipy.ndimage import zoom

        if block_info is None or block.size == 0:
            return block
        row_idx, col_idx = block_info[0]['chunk-location']
        rows = target_rows[row_idx]
        cols = target_cols[col_idx]
        if rows == 0 or cols == 0:
            # This chunk rounds down to nothing in the output grid.
            return np.empty((rows, cols), dtype=block.dtype)
        scale = (rows / block.shape[0], cols / block.shape[1])
        return zoom(block, scale, order=1)

    zoomed = da.map_blocks(
        _zoom_block,
        agg.data,
        dtype=agg.dtype,
        chunks=(target_rows, target_cols),
    )

    new_coords = {
        dim: _interpolate_coords(agg.coords[dim], size)
        for dim, size in ((y_dim, out_h), (x_dim, out_w))
        if dim in agg.coords
    }
    return xr.DataArray(
        zoomed,
        dims=[y_dim, x_dim],
        coords=new_coords,
        attrs=agg.attrs,
    )


def _preview_bilinear(agg, out_h, out_w, y_dim, x_dim):
    """Run bilinear resizing on whichever backend holds *agg*'s data.

    Dask arrays are delegated entirely to ``_bilinear_dask``.  Cupy and
    numpy arrays are resized eagerly and re-wrapped in a DataArray with
    resampled coordinates and the original attrs.
    """
    import dask.array as da

    if isinstance(agg.data, da.Array):
        return _bilinear_dask(agg, out_h, out_w, y_dim, x_dim)

    on_gpu = has_cuda_and_cupy() and is_cupy_array(agg.data)
    resize = _bilinear_cupy if on_gpu else _bilinear_numpy
    resized = resize(agg.data, out_h, out_w)

    new_coords = {
        dim: _interpolate_coords(agg.coords[dim], size)
        for dim, size in ((y_dim, out_h), (x_dim, out_w))
        if dim in agg.coords
    }
    return xr.DataArray(
        resized,
        dims=[y_dim, x_dim],
        coords=new_coords,
        attrs=agg.attrs,
    )


# ---------------------------------------------------------------------------
# Coordinate interpolation
# ---------------------------------------------------------------------------

def _interpolate_coords(coords, n_out):
    """Resample a 1D coordinate to *n_out* evenly spaced index positions.

    Interpolation happens in index space (position 0 .. len-1), so the
    direction of the coordinate values -- increasing or decreasing --
    does not matter.  When the input has at most one value, or at most
    one output value is requested, the leading ``max(n_out, 1)`` values
    are returned unchanged instead.
    """
    src = coords.values
    n_src = len(src)
    if n_src <= 1 or n_out <= 1:
        keep = max(n_out, 1)
        return src[:keep]
    sample_at = np.linspace(0, n_src - 1, n_out)
    positions = np.arange(n_src)
    return np.interp(sample_at, positions, src.astype(np.float64))


# ---------------------------------------------------------------------------
# Second-pass refinement
# ---------------------------------------------------------------------------

def _refine_to_target(result, target_h, target_w, y_dim, x_dim):
    """Thin out an in-memory result so it does not exceed the target size.

    Snap-based dask reduction can leave more rows/columns than requested
    (e.g. 1680 instead of 1000).  Evenly spaced rows and columns are
    picked so each axis ends up with ``min(current, target)`` entries.
    The input is returned untouched when it already fits.
    """
    cur_h, cur_w = result.sizes[y_dim], result.sizes[x_dim]
    new_h, new_w = min(cur_h, target_h), min(cur_w, target_w)
    if (new_h, new_w) == (cur_h, cur_w):
        return result

    # Integer positions spread from the first to the last row / column.
    rows = np.linspace(0, cur_h - 1, new_h, dtype=int)
    cols = np.linspace(0, cur_w - 1, new_w, dtype=int)
    picked = result.data[np.ix_(rows, cols)]

    new_coords = {
        dim: _interpolate_coords(result.coords[dim], size)
        for dim, size in ((y_dim, new_h), (x_dim, new_w))
        if dim in result.coords
    }
    return xr.DataArray(
        picked,
        dims=[y_dim, x_dim],
        coords=new_coords,
        attrs=result.attrs,
    )


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

@supports_dataset
def preview(agg, width=1000, height=None, method='mean', name='preview'):
    """Downsample a raster to target pixel dimensions.

    For dask-backed arrays, the operation is lazy: each chunk is reduced
    independently, so peak memory is bounded by the largest chunk plus
    the small output array.  A 30 TB raster can be previewed at
    1000x1000 with only a few MB of RAM.

    Parameters
    ----------
    agg : xr.DataArray
        Input raster (2D).
    width : int, default 1000
        Target width in pixels.  Must be at least 1.
    height : int, optional
        Target height in pixels.  Must be at least 1 when given.  If not
        provided, computed from *width* preserving the aspect ratio of
        *agg*.
    method : str, default 'mean'
        Downsampling method.  One of:

        - ``'mean'``: block averaging.
        - ``'median'``: block median.
        - ``'max'``: block maximum.
        - ``'min'``: block minimum.
        - ``'nearest'``: stride-based subsampling (fastest, no smoothing).
        - ``'bilinear'``: bilinear interpolation via
          ``scipy.ndimage.zoom``.
    name : str, default 'preview'
        Name for the output DataArray.

    Returns
    -------
    xr.DataArray
        Downsampled raster with updated coordinates.  When *agg* is
        already no larger than the target along both axes it is returned
        unchanged.

    Raises
    ------
    ValueError
        If *method* is not one of the supported methods, or if *width* /
        *height* is smaller than 1.
    """
    _validate_raster(agg, func_name='preview', ndim=2)

    if method not in _METHODS:
        raise ValueError(
            f"method must be one of {_METHODS!r}, got {method!r}"
        )
    # A zero size would divide by zero below and a negative one would
    # turn into a negative trim slice; reject both up front.
    if width < 1:
        raise ValueError(f"width must be at least 1, got {width!r}")
    if height is not None and height < 1:
        raise ValueError(f"height must be at least 1, got {height!r}")

    h = agg.sizes[agg.dims[0]]
    w = agg.sizes[agg.dims[1]]

    if height is None:
        height = max(1, round(width * h / w))

    factor_y = max(1, h // height)
    factor_x = max(1, w // width)

    if factor_y <= 1 and factor_x <= 1:
        return agg

    y_dim = agg.dims[0]
    x_dim = agg.dims[1]

    # Save the original targets before snap may widen them.
    target_h, target_w = height, width

    # Detect the dask backend once; dask is an optional dependency.
    try:
        import dask.array as da
    except ImportError:
        is_dask = False
    else:
        is_dask = isinstance(agg.data, da.Array)

    # For dask arrays, snap each factor to the nearest divisor of the
    # chunk size so that every chunk divides evenly and no rechunking
    # is needed.  The output dimensions may overshoot the target; a
    # cheap second pass corrects that below.
    if is_dask:
        factor_y = _snap_factor(agg.data.chunksize[0], factor_y)
        factor_x = _snap_factor(agg.data.chunksize[1], factor_x)
        height = h // factor_y
        width = w // factor_x

    # Pre-trim the input to an exact multiple of the factors so the
    # reduce output is exactly (height, width) without a post-reduce
    # trim.  On dask arrays this only touches boundary chunks, avoiding
    # two extra getitem layers over the (much larger) output grid.
    trim_h = height * factor_y
    trim_w = width * factor_x
    trim = {}
    if trim_h < h:
        trim[y_dim] = slice(0, trim_h)
    if trim_w < w:
        trim[x_dim] = slice(0, trim_w)
    if trim:
        agg = agg.isel(trim)

    if method == 'nearest':
        result = agg.isel(
            {y_dim: slice(None, None, factor_y),
             x_dim: slice(None, None, factor_x)}
        )
    elif method == 'bilinear':
        result = _preview_bilinear(agg, height, width, y_dim, x_dim)
    elif is_dask:
        # mean / median / max / min on a lazy array
        result = _reduce_dask(
            agg, factor_y, factor_x, method, y_dim, x_dim,
        )
    else:
        # mean / median / max / min on an in-memory array
        result = _reduce_local(
            agg, factor_y, factor_x, method, y_dim, x_dim,
        )

    result.name = name
    result.attrs = agg.attrs

    # Second pass: if snap overshot the target, compute the small
    # intermediate and subsample to exact dimensions.
    if (result.sizes[y_dim] > target_h
            or result.sizes[x_dim] > target_w):
        try:
            result = result.compute()
        except (AttributeError, TypeError):
            # Best effort: fall through and subsample whatever we have.
            pass
        result = _refine_to_target(
            result, target_h, target_w, y_dim, x_dim,
        )
        # _refine_to_target builds a fresh DataArray without a name.
        result.name = name

    return result