Source code for xrspatial.reproject

"""Out-of-core CRS reprojection and multi-raster merge.

Public API
----------
reproject(raster, target_crs, ...)
    Reproject a DataArray to a new coordinate reference system.
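merge(rasters, ...)
    Merge multiple DataArrays into a single mosaic.
geoid_height, geoid_height_raster, ellipsoidal_to_orthometric,
orthometric_to_ellipsoidal, depth_to_ellipsoidal, ellipsoidal_to_depth
    Vertical-datum helpers, re-exported from ``._vertical``.
itrf_transform, itrf_frames
    ITRF frame transforms, re-exported from ``._itrf`` (``itrf_frames``
    is ``_itrf.list_frames``).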
"""
from __future__ import annotations

import math

import numpy as np
import xarray as xr

from xrspatial.utils import _validate_raster

from ._crs_utils import (
    _detect_band_nodata,
    _detect_nodata,
    _detect_source_crs,
    _resolve_crs,
)
from ._grid import (
    _chunk_bounds,
    _compute_chunk_layout,
    _compute_output_grid,
    _make_output_coords,
    _validate_grid_params,
)
from ._interpolate import (
    _resample_cupy_native,
    _resample_numpy,
    _validate_resampling,
)
from ._merge import _merge_arrays_cupy, _merge_arrays_numpy, _validate_strategy
from ._transform import ApproximateTransform

from ._vertical import (
    geoid_height,
    geoid_height_raster,
    ellipsoidal_to_orthometric,
    orthometric_to_ellipsoidal,
    depth_to_ellipsoidal,
    ellipsoidal_to_depth,
)
from ._itrf import itrf_transform, list_frames as itrf_frames

# Names exported by ``from xrspatial.reproject import *``: the two
# reprojection entry points defined in this module plus the vertical-datum
# and ITRF helpers re-exported from ``._vertical`` and ``._itrf``.
__all__ = [
    'reproject', 'merge',
    'geoid_height', 'geoid_height_raster',
    'ellipsoidal_to_orthometric', 'orthometric_to_ellipsoidal',
    'depth_to_ellipsoidal', 'ellipsoidal_to_depth',
    'itrf_transform', 'itrf_frames',
]


# ---------------------------------------------------------------------------
# Source geometry helpers
# ---------------------------------------------------------------------------

# Dimension names recognised as the y / x axis: every lowercase base name
# together with its capitalised form, i.e.
#   y: {'y', 'Y', 'lat', 'Lat', 'latitude', 'Latitude'}
#   x: {'x', 'X', 'lon', 'Lon', 'longitude', 'Longitude'}
_Y_NAMES = {
    variant
    for base in ('y', 'lat', 'latitude')
    for variant in (base, base.capitalize())
}
_X_NAMES = {
    variant
    for base in ('x', 'lon', 'longitude')
    for variant in (base, base.capitalize())
}

# Friendly vertical-datum token -> EPSG code. reproject() stores the code in
# attrs['vertical_crs'] so its output follows the same convention as
# xrspatial.geotiff, which also writes EPSG ints to attrs['vertical_crs'].
_VERTICAL_DATUM_EPSG = dict(
    EGM96=5773,        # EGM96 height
    EGM2008=3855,      # EGM2008 height
    ellipsoidal=4979,  # WGS 84 (3D, ellipsoidal height)
)

# "Not passed" marker for the deprecated ``src_vertical_crs`` /
# ``tgt_vertical_crs`` kwargs. A dedicated object (rather than None) lets us
# distinguish an explicit ``src_vertical_crs=None`` from the default, so the
# deprecation warning fires only when the caller really used the old name.
_DEPRECATED = object()


def _resolve_deprecated_vertical_kwarg(old_name, old_val, new_name, new_val):
    """Fold a deprecated vertical-CRS keyword into its renamed replacement.

    Parameters
    ----------
    old_name, new_name : str
        Deprecated and current keyword names; used only in messages.
    old_val : object
        Value passed under the deprecated name, or the ``_DEPRECATED``
        sentinel when the caller did not use that name.
    new_val : object
        Value passed under the current name.

    Returns
    -------
    object
        *new_val* when the deprecated name was not used, otherwise *old_val*.

    Raises
    ------
    TypeError
        When the deprecated name was used and *new_val* is not None. A
        ``DeprecationWarning`` is emitted before the error is raised.
    """
    used_old_name = old_val is not _DEPRECATED
    if not used_old_name:
        return new_val

    import warnings

    # stacklevel=3 skips this helper and reproject() so the warning points
    # at the user's call site.
    warnings.warn(
        f"reproject(): {old_name!r} is deprecated, use {new_name!r} instead.",
        DeprecationWarning,
        stacklevel=3,
    )
    if new_val is None:
        return old_val
    raise TypeError(
        f"reproject(): pass either {new_name!r} or the deprecated "
        f"{old_name!r}, not both."
    )


def _find_spatial_dims(raster):
    """Return the ``(ydim, xdim)`` dimension names of *raster*.

    Dims are matched by name against ``_Y_NAMES`` and ``_X_NAMES``; if
    several dims match the same axis, the last one wins. When either axis
    has no name match, the last two dims of *raster* are returned
    positionally, whatever they are called (a band dim is not skipped).
    """
    dims = raster.dims
    # Scanning in reverse and taking the first hit == "last match wins".
    ydim = next((d for d in reversed(dims) if d in _Y_NAMES), None)
    xdim = next(
        (d for d in reversed(dims) if d in _X_NAMES and d not in _Y_NAMES),
        None,
    )
    if ydim is None or xdim is None:
        return dims[-2], dims[-1]
    return ydim, xdim


# Relative tolerance used by the regular-spacing check below. Coordinates
# read from real GeoTIFFs drift a few ULPs away from perfectly uniform
# spacing after the pixel-to-world transform; 1e-6 accepts that noise while
# still rejecting the single-pixel perturbation case from #2184.
_REGULAR_COORD_RTOL = 1e-6


def _validate_regular_axis(coords, axis_name, func_name, rtol=_REGULAR_COORD_RTOL):
    """Check that a 1-D coordinate axis is strictly monotonic and evenly spaced.

    The pixel-size math in `_source_bounds` and the chunk workers only makes
    sense on a uniform grid; irregular or non-monotonic coordinates would
    otherwise yield silently wrong georeferencing (#2184).

    Parameters
    ----------
    coords : array-like
        Coordinate values along one axis.
    axis_name : str
        Axis label ("x" or "y") used in error messages.
    func_name : str
        Public function name used as the error-message prefix.
    rtol : float
        Maximum allowed deviation of any step from the median step,
        relative to the median step's magnitude.

    Raises
    ------
    ValueError
        If *coords* is not 1-D, contains NaN/inf, is not strictly monotonic
        (zero steps included), or has a step deviating from the median step
        by more than ``rtol * |median step|``.
    """
    values = np.asarray(coords)
    if values.ndim != 1:
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' must be 1-D, "
            f"got shape {values.shape}."
        )

    # One sample has no spacing to check; _source_bounds falls back to a
    # pixel size of 1.0 for that case.
    if values.size < 2:
        return

    if not np.isfinite(values).all():
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' contains non-finite "
            f"values (NaN or inf)."
        )

    # Take differences in float64 so integer (including unsigned) coords
    # cannot wrap and the tolerance math is done in floating point.
    steps = np.diff(np.asarray(values, dtype=np.float64))

    # A zero step fails both tests, so repeated coordinates are rejected.
    # A sign-only comparison would NOT reject them -- keep both tests.
    all_increasing = bool(np.all(steps > 0))
    all_decreasing = bool(np.all(steps < 0))
    if not all_increasing and not all_decreasing:
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' must be strictly "
            f"monotonic (all ascending or all descending). The reproject "
            f"pipeline assumes a uniformly-spaced grid; see #2184."
        )

    median_step = float(np.median(steps))
    deviation = np.abs(steps - median_step)
    worst = float(deviation.max())
    if worst > rtol * abs(median_step):
        # Step i spans samples i and i+1; report both so the caller can
        # find the bad sample without recomputing the differences.
        bad = int(deviation.argmax())
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' is not regularly "
            f"spaced. Median step is {median_step!r}; worst deviation is "
            f"{worst!r} at index {bad} (between {axis_name}[{bad}]"
            f"={float(values[bad])!r} and {axis_name}[{bad + 1}]"
            f"={float(values[bad + 1])!r}). The reproject pipeline "
            f"assumes a uniformly-spaced grid; see #2184."
        )


def _validate_source_coords(raster, func_name):
    """Run the regular-grid check on both spatial axes of *raster*.

    The y axis is checked before the x axis; the first failing axis raises
    the ``ValueError`` from ``_validate_regular_axis``.
    """
    ydim, xdim = _find_spatial_dims(raster)
    for label, dim in (('y', ydim), ('x', xdim)):
        _validate_regular_axis(raster.coords[dim].values, label, func_name)


def _source_bounds(raster):
    """Extract the pixel-edge extent ``(left, bottom, right, top)``.

    The coordinates of *raster* are pixel centres, so each edge is pushed
    outward by half a pixel. The pixel size on each axis is the absolute gap
    between the first two coordinate samples, which is only meaningful on a
    regular grid (callers are expected to run ``_validate_source_coords``
    first). An axis with a single sample uses a pixel size of 1.0.
    """
    def _pixel_size(values):
        # Subtract in float64 rather than in the coords' native dtype: for
        # unsigned-integer coords, values[1] - values[0] wraps around on a
        # descending axis (e.g. uint8 2 - 3 == 255).
        if len(values) > 1:
            return abs(float(values[1]) - float(values[0]))
        return 1.0

    ydim, xdim = _find_spatial_dims(raster)
    y = raster.coords[ydim].values
    x = raster.coords[xdim].values
    res_x = _pixel_size(x)
    res_y = _pixel_size(y)
    x_min, x_max = float(np.min(x)), float(np.max(x))
    y_min, y_max = float(np.min(y)), float(np.max(y))
    left = x_min - res_x / 2
    right = x_max + res_x / 2
    bottom = y_min - res_y / 2
    top = y_max + res_y / 2
    return (left, bottom, right, top)


def _is_y_descending(raster):
    """Return True when the y coordinate runs from top (large) to bottom (small).

    An axis with fewer than two samples has no direction and is reported
    as descending (the top-down default).
    """
    ydim, _ = _find_spatial_dims(raster)
    ys = raster.coords[ydim].values
    return len(ys) < 2 or float(ys[0]) > float(ys[-1])


def _is_x_descending(raster):
    """Return True when the x coordinate runs from right (large) to left (small).

    Horizontal counterpart of :func:`_is_y_descending`. A single-column
    raster is reported as ascending, matching :func:`_make_output_coords`,
    which always emits ascending x.
    """
    _, xdim = _find_spatial_dims(raster)
    xs = raster.coords[xdim].values
    return len(xs) >= 2 and float(xs[0]) > float(xs[-1])


# ---------------------------------------------------------------------------
# Per-chunk coordinate transform
# ---------------------------------------------------------------------------

def _transform_coords(transformer, chunk_bounds, chunk_shape,
                      transform_precision, src_crs=None, tgt_crs=None):
    """Map every output pixel centre of a chunk into the source CRS.

    Strategy, in order:

    * ``transform_precision != 0`` and both CRSes given: try the Numba JIT
      fast path for common CRS pairs (WGS84/NAD83 <-> UTM, WGS84 <-> Web
      Mercator), which bypasses pyproj entirely (~30x faster).
    * ``transform_precision == 0``: exact per-pixel transform through
      *transformer* (same strategy as GDAL/rasterio).
    * Otherwise: bilinear interpolation on a coarse control grid via
      ``ApproximateTransform``.

    Parameters
    ----------
    transformer : pyproj.Transformer-like
        Object whose ``transform(xx, yy)`` maps target-CRS coordinates to
        source-CRS coordinates and returns ``(x, y)``.
    chunk_bounds : (left, bottom, right, top)
        Extent of the output chunk in the target CRS.
    chunk_shape : (height, width)
        Output chunk size in pixels.
    transform_precision : int
        0 for exact transforms; otherwise the control-grid precision.
    src_crs, tgt_crs : optional
        CRS objects, needed only for the Numba fast path.

    Returns
    -------
    src_y, src_x : ndarray (height, width)
    """
    exact = transform_precision == 0

    # The exact mode is the documented escape hatch from approximations, so
    # the Numba fast path is only tried when exact transforms were not asked for.
    if not exact and src_crs is not None and tgt_crs is not None:
        try:
            from ._projections import try_numba_transform
            fast = try_numba_transform(
                src_crs, tgt_crs, chunk_bounds, chunk_shape,
            )
        except (ImportError, ModuleNotFoundError):
            fast = None  # Numba path unavailable; fall through to pyproj
        if fast is not None:
            return fast

    height, width = chunk_shape
    left, bottom, right, top = chunk_bounds
    res_x = (right - left) / width
    res_y = (top - bottom) / height

    if not exact:
        # Approximate: bilinear interpolation on a coarse control grid.
        approx = ApproximateTransform(
            transformer, chunk_bounds, chunk_shape,
            precision=transform_precision,
        )
        rows = np.broadcast_to(
            np.arange(height, dtype=np.float64)[:, np.newaxis], (height, width),
        )
        cols = np.broadcast_to(
            np.arange(width, dtype=np.float64)[np.newaxis, :], (height, width),
        )
        return approx(rows, cols)

    # Exact per-pixel transform via the transformer's bulk API, one strip of
    # rows at a time so the temporary flattened inputs stay proportional to
    # a strip rather than the whole chunk.
    col_centres = left + (np.arange(width, dtype=np.float64) + 0.5) * res_x
    out_x = np.empty((height, width), dtype=np.float64)
    out_y = np.empty((height, width), dtype=np.float64)
    rows_per_strip = 256
    for start in range(0, height, rows_per_strip):
        stop = min(start + rows_per_strip, height)
        n_rows = stop - start
        row_centres = top - (np.arange(start, stop, dtype=np.float64) + 0.5) * res_y
        # Flatten the strip's (n_rows, width) grid in row-major order. This
        # materialises n_rows * width points for each input array.
        sx, sy = transformer.transform(
            np.tile(col_centres, n_rows),
            np.repeat(row_centres, width),
        )
        out_x[start:stop] = np.asarray(sx, dtype=np.float64).reshape(n_rows, width)
        out_y[start:stop] = np.asarray(sy, dtype=np.float64).reshape(n_rows, width)
    return out_y, out_x


# ---------------------------------------------------------------------------
# Per-chunk worker functions
# ---------------------------------------------------------------------------

def _reproject_chunk_numpy(
    source_data, source_bounds_tuple, source_shape, source_y_desc,
    src_wkt, tgt_wkt,
    chunk_bounds_tuple, chunk_shape,
    resampling, nodata, transform_precision,
    source_x_desc=False,
    band_nodata=None,
):
    """Reproject a single output chunk (numpy backend).

    Called inside ``dask.delayed`` for the dask path, or directly for numpy.
    CRS objects are passed as WKT strings for pickle safety.

    Parameters
    ----------
    source_data : array-like
        Source raster, ``(rows, cols)`` or ``(rows, cols, bands)``. Only the
        window this chunk needs is sliced out; a window exposing
        ``compute`` (e.g. a dask slice) is materialised after slicing.
    source_bounds_tuple : (left, bottom, right, top)
        Pixel-edge extent of the source in the source CRS.
    source_shape : (int, int)
        Source ``(height, width)`` in pixels.
    source_y_desc, source_x_desc : bool
        True when source row 0 / column 0 sits at the maximum y / x.
        ``source_x_desc`` defaults to False so older callers keep working.
    src_wkt, tgt_wkt : str
        Source and target CRS as WKT.
    chunk_bounds_tuple : (left, bottom, right, top)
        Extent of this output chunk in the target CRS.
    chunk_shape : (int, int)
        Output chunk ``(height, width)``.
    resampling : str
        Method forwarded to ``_resample_numpy``.
    nodata : float
        Output fill value; also the source sentinel unless *band_nodata*
        supplies one per band.
    transform_precision : int
        0 forces exact per-pixel pyproj transforms; otherwise the Numba
        fast path or the approximate control-grid transform may be used.
    band_nodata : sequence of float or None
        Optional per-band source sentinels for 3-D sources (#2647).

    Returns
    -------
    numpy.ndarray
        ``chunk_shape`` for 2-D sources, ``chunk_shape + (bands,)`` for 3-D
        sources. Integer sources are rounded, clipped and cast back to their
        original dtype. Chunks with no usable source data are filled with
        *nodata*; they also use the integer source dtype (when *nodata* is
        finite), so every chunk of one output shares a dtype. All other
        sources yield float64.
    """
    from ._crs_utils import _crs_from_wkt

    src_crs = _crs_from_wkt(src_wkt)
    tgt_crs = _crs_from_wkt(tgt_wkt)

    # 3-D source: empty-chunk returns must carry the band axis or the
    # dask map_blocks template (which is 3-D for 3-D sources) sees a
    # shape mismatch (#2027).
    if source_data.ndim == 3:
        _empty_shape = (*chunk_shape, source_data.shape[2])
    else:
        _empty_shape = chunk_shape

    # Empty chunks must also match the dtype of resampled chunks: integer
    # sources are cast back to their own dtype at the end of this function,
    # so filling empty chunks with float64 would mix dtypes across the
    # blocks of one output. A non-finite nodata cannot be stored in an
    # integer array, so that degenerate case keeps float64.
    if (np.issubdtype(source_data.dtype, np.integer)
            and nodata is not None and np.isfinite(nodata)):
        _empty_dtype = source_data.dtype
    else:
        _empty_dtype = np.float64

    def _empty_chunk():
        # Output for a chunk that samples no source pixels.
        return np.full(_empty_shape, nodata, dtype=_empty_dtype)

    # Try Numba fast path first (avoids creating pyproj Transformer).
    # transform_precision == 0 forces the exact pyproj path, so skip Numba.
    numba_result = None
    if transform_precision != 0:
        try:
            from ._projections import try_numba_transform
            numba_result = try_numba_transform(
                src_crs, tgt_crs, chunk_bounds_tuple, chunk_shape,
            )
        except (ImportError, ModuleNotFoundError):
            pass

    if numba_result is not None:
        src_y, src_x = numba_result
    else:
        # Fallback: create pyproj Transformer (expensive)
        from ._crs_utils import _require_pyproj
        pyproj = _require_pyproj()
        transformer = pyproj.Transformer.from_crs(
            tgt_crs, src_crs, always_xy=True
        )
        src_y, src_x = _transform_coords(
            transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
            src_crs=src_crs, tgt_crs=tgt_crs,
        )

    # Convert source CRS coordinates to fractional source pixel coordinates
    # (pixel centres sit at integer positions, hence the -0.5).
    src_left, src_bottom, src_right, src_top = source_bounds_tuple
    src_h, src_w = source_shape
    src_res_x = (src_right - src_left) / src_w
    src_res_y = (src_top - src_bottom) / src_h

    if source_x_desc:
        src_col_px = (src_right - src_x) / src_res_x - 0.5
    else:
        src_col_px = (src_x - src_left) / src_res_x - 0.5
    if source_y_desc:
        src_row_px = (src_top - src_y) / src_res_y - 0.5
    else:
        src_row_px = (src_y - src_bottom) / src_res_y - 0.5

    # Determine source window needed
    r_min = np.nanmin(src_row_px)
    r_max = np.nanmax(src_row_px)
    c_min = np.nanmin(src_col_px)
    c_max = np.nanmax(src_col_px)

    if not np.isfinite(r_min) or not np.isfinite(r_max):
        return _empty_chunk()
    if not np.isfinite(c_min) or not np.isfinite(c_max):
        return _empty_chunk()

    # Pad the window so the resampling kernels have neighbours at the edges.
    r_min = int(np.floor(r_min)) - 2
    r_max = int(np.ceil(r_max)) + 3
    c_min = int(np.floor(c_min)) - 2
    c_max = int(np.ceil(c_max)) + 3

    # Check overlap
    if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
        return _empty_chunk()

    # Clip to source bounds
    r_min_clip = max(0, r_min)
    r_max_clip = min(src_h, r_max)
    c_min_clip = max(0, c_min)
    c_max_clip = min(src_w, c_max)

    # Guard: cap source window to prevent OOM if projection maps a small
    # output chunk to a huge source area (e.g. polar stereographic edges).
    _MAX_WINDOW_PIXELS = 64 * 1024 * 1024  # 64 Mpix (~512 MB for float64)
    win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip)
    if win_pixels > _MAX_WINDOW_PIXELS:
        return _empty_chunk()

    # Extract source window
    window = source_data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
    if hasattr(window, 'compute'):
        window = window.compute()
    window = np.asarray(window)
    orig_dtype = window.dtype

    # Adjust coordinates relative to window
    local_row = src_row_px - r_min_clip
    local_col = src_col_px - c_min_clip

    # Multi-band: reproject each band separately, share coordinates
    if window.ndim == 3:
        n_bands = window.shape[2]
        bands = []
        for b in range(n_bands):
            band_data = window[:, :, b].astype(np.float64)
            # Mask this band with its own source sentinel when the raster
            # declares per-band nodata; otherwise fall back to the single
            # resolved sentinel (#2647).
            src_nd = band_nodata[b] if band_nodata is not None else nodata
            if not np.isnan(src_nd):
                band_data[band_data == src_nd] = np.nan
            band_result = _resample_numpy(band_data, local_row, local_col,
                                          resampling=resampling, nodata=nodata)
            if np.issubdtype(orig_dtype, np.integer):
                info = np.iinfo(orig_dtype)
                band_result = np.clip(np.round(band_result), info.min, info.max).astype(orig_dtype)
            bands.append(band_result)
        return np.stack(bands, axis=-1)

    # Single-band path
    window = window.astype(np.float64)

    # Convert sentinel nodata to NaN so numba kernels can detect it
    if not np.isnan(nodata):
        window[window == nodata] = np.nan

    result = _resample_numpy(window, local_row, local_col,
                             resampling=resampling, nodata=nodata)

    # Clamp and cast back for integer source dtypes
    if np.issubdtype(orig_dtype, np.integer):
        info = np.iinfo(orig_dtype)
        result = np.clip(np.round(result), info.min, info.max).astype(orig_dtype)

    return result


def _reproject_chunk_cupy(
    source_data, source_bounds_tuple, source_shape, source_y_desc,
    src_wkt, tgt_wkt,
    chunk_bounds_tuple, chunk_shape,
    resampling, nodata, transform_precision,
    source_x_desc=False,
    band_nodata=None,
):
    """CuPy variant of ``_reproject_chunk_numpy``.

    Takes the same arguments and follows the same steps: map output pixel
    centres into the source CRS, convert them to fractional source pixel
    coordinates, slice the padded source window and resample it.
    ``source_x_desc`` carries the horizontal direction flag (same meaning
    as in :func:`_reproject_chunk_numpy`).

    Coordinates come from the CUDA transform when it is importable, supports
    the CRS pair and ``transform_precision != 0`` (they then stay on the
    device). Otherwise they are computed on the CPU via
    :func:`_transform_coords`. Either way, resampling goes through
    ``_resample_cupy_native`` so results match the numpy backend (#2620).

    Returns
    -------
    cupy.ndarray
        ``chunk_shape`` (plus a trailing band axis for 3-D sources).
        Integer sources come back in their original dtype. Chunks with no
        usable source data are filled with *nodata* and use that same
        integer dtype when *nodata* is finite, so every chunk of one output
        shares a dtype. All other sources yield float64.
    """
    import cupy as cp

    from ._crs_utils import _crs_from_wkt

    src_crs = _crs_from_wkt(src_wkt)
    tgt_crs = _crs_from_wkt(tgt_wkt)

    # 3-D source: empty-chunk returns must carry the band axis to match
    # the dask+cupy map_blocks template (#2027).
    if source_data.ndim == 3:
        _empty_shape = (*chunk_shape, source_data.shape[2])
    else:
        _empty_shape = chunk_shape

    # Empty chunks must match the dtype of resampled chunks, which are cast
    # back to the source dtype for integer sources (see the end of this
    # function). A non-finite nodata cannot be stored in an integer array,
    # so that degenerate case keeps float64.
    if (np.issubdtype(source_data.dtype, np.integer)
            and nodata is not None and np.isfinite(nodata)):
        _empty_dtype = source_data.dtype
    else:
        _empty_dtype = cp.float64

    def _empty_chunk():
        # Output for a chunk that samples no source pixels.
        return cp.full(_empty_shape, nodata, dtype=_empty_dtype)

    # Try CUDA transform first (keeps coordinates on-device).
    # transform_precision == 0 forces the exact pyproj path, so skip CUDA.
    cuda_result = None
    if (transform_precision != 0
            and src_crs is not None and tgt_crs is not None):
        try:
            from ._projections_cuda import try_cuda_transform
            cuda_result = try_cuda_transform(
                src_crs, tgt_crs, chunk_bounds_tuple, chunk_shape,
            )
        except (ImportError, ModuleNotFoundError):
            pass

    on_device = cuda_result is not None
    if on_device:
        src_y, src_x = cuda_result  # cupy arrays
    else:
        # CPU fallback (Numba JIT or pyproj, chosen inside _transform_coords)
        from ._crs_utils import _require_pyproj
        pyproj = _require_pyproj()
        transformer = pyproj.Transformer.from_crs(
            tgt_crs, src_crs, always_xy=True
        )
        src_y, src_x = _transform_coords(
            transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
            src_crs=src_crs, tgt_crs=tgt_crs,
        )

    # Source CRS -> fractional source pixel coordinates. Plain operators, so
    # this runs on the GPU for cupy inputs and on the CPU for numpy inputs.
    src_left, src_bottom, src_right, src_top = source_bounds_tuple
    src_h, src_w = source_shape
    src_res_x = (src_right - src_left) / src_w
    src_res_y = (src_top - src_bottom) / src_h
    if source_x_desc:
        src_col_px = (src_right - src_x) / src_res_x - 0.5
    else:
        src_col_px = (src_x - src_left) / src_res_x - 0.5
    if source_y_desc:
        src_row_px = (src_top - src_y) / src_res_y - 0.5
    else:
        src_row_px = (src_y - src_bottom) / src_res_y - 0.5

    # Window selection needs the coordinate extremes on the host.
    if on_device:
        # Stack the four reductions and pull them across in one
        # device-to-host transfer instead of four separate syncs.
        mins_maxes = cp.stack([
            cp.nanmin(src_row_px), cp.nanmax(src_row_px),
            cp.nanmin(src_col_px), cp.nanmax(src_col_px),
        ])
        r_min_val, r_max_val, c_min_val, c_max_val = (
            float(v) for v in mins_maxes.get()
        )
    else:
        r_min_val = np.nanmin(src_row_px)
        r_max_val = np.nanmax(src_row_px)
        c_min_val = np.nanmin(src_col_px)
        c_max_val = np.nanmax(src_col_px)

    if not (np.isfinite(r_min_val) and np.isfinite(r_max_val)
            and np.isfinite(c_min_val) and np.isfinite(c_max_val)):
        return _empty_chunk()

    # Pad the window so the resampling kernels have neighbours at the edges.
    r_min = int(np.floor(r_min_val)) - 2
    r_max = int(np.ceil(r_max_val)) + 3
    c_min = int(np.floor(c_min_val)) - 2
    c_max = int(np.ceil(c_max_val)) + 3

    if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
        return _empty_chunk()

    r_min_clip = max(0, r_min)
    r_max_clip = min(src_h, r_max)
    c_min_clip = max(0, c_min)
    c_max_clip = min(src_w, c_max)

    # Guard: cap source window to prevent GPU OOM if projection maps a
    # small output chunk to a huge source area (matches numpy path).
    _MAX_WINDOW_PIXELS = 64 * 1024 * 1024  # 64 Mpix (~512 MB for float64)
    win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip)
    if win_pixels > _MAX_WINDOW_PIXELS:
        return _empty_chunk()

    window = source_data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
    if hasattr(window, 'compute'):
        window = window.compute()
    if not isinstance(window, cp.ndarray):
        window = cp.asarray(window)
    orig_dtype = window.dtype

    # Adjust coordinates relative to window (stays on GPU if CuPy)
    local_row = src_row_px - r_min_clip
    local_col = src_col_px - c_min_clip

    # Multi-band: reproject each band separately, share coordinates.
    # Matches the 3-D branch in _reproject_chunk_numpy so 3-D inputs work
    # on cupy and dask+cupy backends instead of crashing with a CUDA
    # signature mismatch (#2027).
    if window.ndim == 3:
        n_bands = window.shape[2]
        bands = []
        for b in range(n_bands):
            band_data = window[:, :, b].astype(cp.float64)
            # Mask this band with its own source sentinel when the raster
            # declares per-band nodata; otherwise fall back to the single
            # resolved sentinel (#2647). Pre-converting to NaN here lets
            # each band use a different source sentinel; the native kernel
            # still fills out-of-bounds pixels with the resolved `nodata`.
            src_nd = band_nodata[b] if band_nodata is not None else nodata
            if not np.isnan(src_nd):
                band_data = cp.where(
                    band_data == src_nd, cp.nan, band_data,
                )
            # Always resample through the native CUDA kernels so the cupy
            # backend matches numpy exactly. They accept CPU coordinate
            # arrays (transferring them to the GPU) and do the
            # nodata->NaN conversion internally, so they serve both the
            # on-device coordinate path and the pyproj fallback. Using
            # cupyx.scipy.ndimage.map_coordinates here instead would
            # diverge from numpy: it bleeds the cval=0.0 constant into the
            # half-pixel boundary band rather than renormalizing, and its
            # order=3 path is a B-spline rather than Catmull-Rom (#2620).
            band_result = _resample_cupy_native(
                band_data, local_row, local_col,
                resampling=resampling, nodata=nodata,
            )
            if np.issubdtype(orig_dtype, np.integer):
                info = np.iinfo(orig_dtype)
                band_result = cp.clip(
                    cp.round(band_result), info.min, info.max,
                ).astype(orig_dtype)
            bands.append(band_result)
        return cp.stack(bands, axis=-1)

    window = window.astype(cp.float64)

    # Always resample through the native CUDA kernels for numpy parity.
    # local_row/local_col may be CuPy (on-device transform) or numpy
    # (pyproj fallback); _resample_cupy_native handles both and does the
    # nodata->NaN conversion internally. The previous
    # cupyx.scipy.ndimage.map_coordinates fallback diverged from numpy at
    # chunk edges and for cubic resampling (#2620).
    result = _resample_cupy_native(window, local_row, local_col,
                                   resampling=resampling, nodata=nodata)

    # Clamp and cast back for integer source dtypes (parity with numpy path)
    if np.issubdtype(orig_dtype, np.integer):
        info = np.iinfo(orig_dtype)
        result = cp.clip(cp.round(result), info.min, info.max).astype(orig_dtype)
    return result


# ---------------------------------------------------------------------------
# reproject()
# ---------------------------------------------------------------------------

[docs] def reproject( raster, target_crs, *, source_crs=None, resolution=None, bounds=None, width=None, height=None, resampling='bilinear', nodata=None, transform_precision=16, chunk_size=None, name=None, max_memory=None, source_vertical_crs=None, target_vertical_crs=None, bounds_policy="auto", src_vertical_crs=_DEPRECATED, tgt_vertical_crs=_DEPRECATED, ): """Reproject a raster DataArray to a new coordinate reference system. Supports numpy, cupy, dask+numpy, and dask+cupy backends. For dask inputs, the computation is fully lazy: each output chunk independently reads only the source pixels it needs. Parameters ---------- raster : xr.DataArray Input raster with y/x coordinates. target_crs Target CRS in any format accepted by ``pyproj.CRS()``. source_crs : optional Source CRS. Auto-detected from *raster* if None. resolution : float or (float, float) or None Output pixel size in target CRS units. bounds : (left, bottom, right, top) or None Explicit output extent in target CRS. width, height : int or None Explicit output grid dimensions. resampling : str One of 'nearest', 'bilinear', 'cubic'. nodata : float or None Nodata value. Auto-detected if None. For integer input dtypes, an explicit value that does not fit the dtype range raises ``ValueError`` (e.g. ``nodata=-9999`` with a ``uint8`` raster). Attrs/rioxarray-derived out-of-range values emit a ``UserWarning`` and fall back to ``dtype.min`` for signed or ``dtype.max`` for unsigned so legacy files still load (#2572). transform_precision : int Control-grid subdivisions for the coordinate transform (default 16). Higher values increase accuracy at the cost of more pyproj calls. Set to 0 for exact per-pixel transforms matching GDAL/rasterio. chunk_size : int or (int, int) or None Output chunk size for dask. If None, defaults to 512 for the standard dask path and 2048 for the in-memory streaming and dask+cupy paths (chosen to amortize kernel launch overhead). name : str or None Name for the output DataArray. 
max_memory : int or str or None Maximum memory budget for the reprojection working set. Accepts bytes (int) or human-readable strings like ``'4GB'``, ``'512MB'``. Controls how many output tiles are processed in parallel for large-dataset streaming mode. Default None uses 1GB. Has no effect for small datasets that fit in memory. source_vertical_crs : str or None Source vertical datum for height values. One of: - ``'EGM96'`` -- orthometric heights relative to EGM96 geoid (MSL) - ``'EGM2008'`` -- orthometric heights relative to EGM2008 geoid - ``'ellipsoidal'`` -- heights relative to the WGS84 ellipsoid - ``None`` -- no vertical transformation (default) target_vertical_crs : str or None Target vertical datum. Same options as *source_vertical_crs*. Both must be set to trigger a vertical transformation. src_vertical_crs : str or None Deprecated alias for *source_vertical_crs*. Passing it emits a ``DeprecationWarning``. tgt_vertical_crs : str or None Deprecated alias for *target_vertical_crs*. Passing it emits a ``DeprecationWarning``. bounds_policy : {"auto", "raw", "clamp", "percentile"}, default "auto" How to derive the output extent from the source extent when ``bounds`` is not supplied. Only relevant when projecting near a singularity (antimeridian, pole, projection edge): - ``"raw"``: use the true projected extent of the source corners and edges. No clamp, no percentile, no heuristic. The output may be very large if the input straddles a projection singularity. Use this when you want a true projection of the source extent. - ``"clamp"``: trim geographic source bounds inward by 0.01 deg from +/-180 longitude and +/-90 latitude before projecting. Avoids infinities at singularities. No percentile fallback. No-op on projected source CRSes (UTM, Mercator, etc.) since the clamp only applies in degrees. - ``"percentile"``: project a dense interior grid of the source extent and use the 2nd/98th percentiles of the result as the output bounds. 
Rejects projection outliers at the cost of trimming valid pixels. - ``"auto"`` (default): apply ``"clamp"`` for geographic source CRSes and fall back to ``"percentile"`` when the projected extent is more than 50x the source extent. Matches the historical behaviour. When ``"auto"``, ``"clamp"``, or ``"percentile"`` actually alters the bounds, a ``UserWarning`` is emitted naming the policy and reporting the per-side delta versus the raw projected bounds. Filter with ``warnings.filterwarnings`` if the crop is intentional. Returns ------- xr.DataArray The output ``attrs['crs']`` is in WKT format. Whenever *target_vertical_crs* is set, ``attrs['vertical_crs']`` records the target vertical datum's EPSG code (5773 for EGM96, 3855 for EGM2008, 4979 for ellipsoidal WGS84) to match the convention used by ``xrspatial.geotiff``. The friendly string token (``'EGM96'`` etc.) is preserved under ``attrs['vertical_datum']``. Both attrs are written even when no shift is applied (e.g. when *source_vertical_crs* equals *target_vertical_crs*, or when only the target is given), so the output's vertical reference is always explicit. The output y coordinate is always emitted in descending order (top-down, north-up) and the output x coordinate is always emitted in ascending order (left-to-right) regardless of the input directions. This matches the standard raster convention and the output of common GIS libraries. Inputs with descending x are detected from the x coordinate values and handled the same way as descending y: the pixel-index mapping is mirrored so the output values stay correct. Non-spatial coords from the input (such as a scalar ``time`` coord or a non-dimension coord that is not aligned to the spatial dims) are carried through to the output. Coords that are aligned to the input y or x dims are dropped because their values do not apply to the rebuilt grid. 
Examples -------- >>> import xarray as xr >>> import numpy as np >>> from xrspatial.reproject import reproject >>> raster = xr.DataArray( ... np.random.rand(64, 64), ... dims=['y', 'x'], ... coords={'y': np.linspace(50, 40, 64), ... 'x': np.linspace(-5, 5, 64)}, ... attrs={'crs': 'EPSG:4326'}, ... ) >>> result = reproject(raster, 'EPSG:3857') >>> result.attrs['crs'].startswith(('PROJCRS', 'PROJCS')) True """ # Back-compat shim for the old abbreviated kwarg names. These were # renamed to source_vertical_crs / target_vertical_crs to match the # source_crs / target_crs spelling used by the rest of the signature. source_vertical_crs = _resolve_deprecated_vertical_kwarg( 'src_vertical_crs', src_vertical_crs, 'source_vertical_crs', source_vertical_crs) target_vertical_crs = _resolve_deprecated_vertical_kwarg( 'tgt_vertical_crs', tgt_vertical_crs, 'target_vertical_crs', target_vertical_crs) _validate_raster(raster, func_name='reproject', name='raster', ndim=(2, 3)) # Reject irregular / non-monotonic source coords before any CRS # resolution or grid math. _source_bounds() infers pixel size from # only the first two coord samples and downstream pixel math assumes # uniform spacing, so an unchecked irregular input would silently # produce wrong georeferencing (#2184). _validate_source_coords(raster, 'reproject') _validate_grid_params( resolution=resolution, bounds=bounds, width=width, height=height, transform_precision=transform_precision, func_name='reproject', ) _validate_resampling(resampling) from ._grid import _validate_bounds_policy _validate_bounds_policy(bounds_policy, func_name='reproject') # Reject unknown vertical-datum tokens at the API boundary so we never # write None into attrs['vertical_crs'] for typos / unsupported values. 
for _name, _val in (('source_vertical_crs', source_vertical_crs), ('target_vertical_crs', target_vertical_crs)): if _val is not None and _val not in _VERTICAL_DATUM_EPSG: raise ValueError( f"Unknown {_name}={_val!r}; expected one of " f"{sorted(_VERTICAL_DATUM_EPSG)} or None." ) # Normalize 3-D inputs to canonical (y, x, band) layout. # The per-chunk workers slice the source as ``source_data[r:, c:]`` and # assume the band axis is trailing. A rasterio/rioxarray-style # ``(band, y, x)`` input would otherwise slice the band/y axes instead # of the y/x axes and either crash or return wrong-shape data (#2182). # We record the input's original dim order so the output can be # transposed back at the end, preserving downstream expectations. _input_dims = tuple(raster.dims) if raster.ndim == 3: _ydim_in, _xdim_in = _find_spatial_dims(raster) _band_dims_in = [d for d in _input_dims if d not in (_ydim_in, _xdim_in)] _band_dim_in = _band_dims_in[0] if _band_dims_in else None if _band_dim_in is not None: _canonical = (_ydim_in, _xdim_in, _band_dim_in) if _input_dims != _canonical: raster = raster.transpose(*_canonical) # Resolve CRS src_crs = _resolve_crs(source_crs) if src_crs is None: src_crs = _detect_source_crs(raster) if src_crs is None: raise ValueError( "Could not detect source CRS. Pass source_crs explicitly." ) tgt_crs = _resolve_crs(target_crs) # Detect nodata. Pass the raster dtype so integer rasters get an # integer-compatible sentinel (dtype min for signed, dtype max for # unsigned) instead of NaN. Without this hint, the worker's # cast-back step would collapse NaN to 0 and `attrs['nodata']` # would contradict the array contents (#2185). nd = _detect_nodata(raster, nodata, dtype=raster.dtype) # Multi-band rasters can declare a distinct source sentinel per band # via the rasterio `nodatavals` tuple. 
`nd` is the single resolved # output sentinel; `band_nd` carries the raw per-band source sentinels # so each band is masked with its own value before resampling (#2647). # `None` means one scalar covers every band -- the workers use `nd`. # The raster is in canonical (y, x, band) layout here, so the band # axis is trailing. _n_bands = raster.shape[2] if raster.ndim == 3 else None band_nd = _detect_band_nodata(raster, nodata, _n_bands) # Source geometry src_bounds = _source_bounds(raster) _ydim, _xdim = _find_spatial_dims(raster) src_shape = (raster.sizes[_ydim], raster.sizes[_xdim]) y_desc = _is_y_descending(raster) x_desc = _is_x_descending(raster) # Compute output grid grid = _compute_output_grid( src_bounds, src_shape, src_crs, tgt_crs, resolution=resolution, bounds=bounds, width=width, height=height, bounds_policy=bounds_policy, ) out_bounds = grid['bounds'] out_shape = grid['shape'] # Output coordinates y_coords, x_coords = _make_output_coords(out_bounds, out_shape) # Detect backend from ..utils import has_dask_array, is_cupy_array data = raster.data is_dask = False if has_dask_array(): import dask.array as _da is_dask = isinstance(data, _da.Array) is_cupy = False if is_dask: # Check underlying type try: from ..utils import is_cupy_backed is_cupy = is_cupy_backed(raster) except (ImportError, ModuleNotFoundError): pass else: is_cupy = is_cupy_array(data) # For large in-memory datasets, wrap in dask for chunked processing. # map_blocks generates an O(1) HighLevelGraph (single blockwise layer) # so graph metadata is no longer a concern -- the streaming fallback # is only needed when dask itself is unavailable. 
_use_streaming = False if not is_dask and not is_cupy: nbytes = src_shape[0] * src_shape[1] * data.dtype.itemsize if data.ndim == 3: nbytes *= data.shape[2] _OOM_THRESHOLD = 512 * 1024 * 1024 # 512 MB if nbytes > _OOM_THRESHOLD: cs = chunk_size or 2048 if isinstance(cs, int): cs = (cs, cs) try: import dask.array as _da data = _da.from_array(data, chunks=cs) raster = xr.DataArray( data, dims=raster.dims, coords=raster.coords, name=raster.name, attrs=raster.attrs, ) is_dask = True except ImportError: # dask not available -- fall back to streaming _use_streaming = True # Serialize CRS for pickle safety src_wkt = src_crs.to_wkt() tgt_wkt = tgt_crs.to_wkt() if _use_streaming: result_data = _reproject_streaming( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nd, transform_precision, chunk_size or 2048, _parse_max_memory(max_memory), x_desc=x_desc, band_nodata=band_nd, ) elif is_dask and is_cupy: result_data = _reproject_dask_cupy( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nd, transform_precision, chunk_size, x_desc=x_desc, band_nodata=band_nd, ) elif is_dask: result_data = _reproject_dask( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nd, transform_precision, chunk_size, False, x_desc=x_desc, band_nodata=band_nd, ) elif is_cupy: result_data = _reproject_inmemory_cupy( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nd, transform_precision, x_desc=x_desc, band_nodata=band_nd, ) else: result_data = _reproject_inmemory_numpy( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nd, transform_precision, x_desc=x_desc, band_nodata=band_nd, ) # Vertical datum transformation (if requested) if source_vertical_crs is not None and target_vertical_crs is not None: if source_vertical_crs != target_vertical_crs: result_data, nd = _apply_vertical_shift( 
result_data, y_coords, x_coords, source_vertical_crs, target_vertical_crs, nd, tgt_crs_wkt=tgt_wkt, ) ydim, xdim = _find_spatial_dims(raster) # Carry input attrs forward so units, long_name, scale_factor, etc. # survive the transform. Re-emit `transform` and `res` for the new # output grid (rasterio 6-tuple convention). Drop `crs_wkt` since it # would duplicate (or contradict) the canonical `crs` we re-emit. out_left, _out_bottom, _out_right, out_top = grid['bounds'] out_res_x = grid['res_x'] out_res_y = grid['res_y'] out_attrs = {**raster.attrs} out_attrs.pop('crs_wkt', None) out_attrs['crs'] = tgt_wkt out_attrs['nodata'] = nd out_attrs['res'] = (out_res_x, out_res_y) out_attrs['transform'] = ( out_res_x, 0.0, out_left, 0.0, -out_res_y, out_top, ) # If the input used `_FillValue`, propagate the resolved nodata to # both keys so downstream consumers reading either find a consistent # value. If `_FillValue` was absent, leave it absent. if '_FillValue' in raster.attrs: out_attrs['_FillValue'] = nd # `nodatavals` (rasterio convention) is a tuple of per-band sentinels. # Refresh it to the resolved nodata so it doesn't contradict # ``out_attrs['nodata']`` after the resample. if 'nodatavals' in raster.attrs: old_nv = raster.attrs['nodatavals'] try: n_entries = max(1, len(old_nv)) except TypeError: n_entries = 1 out_attrs['nodatavals'] = tuple(nd for _ in range(n_entries)) if target_vertical_crs is not None: # Align with xrspatial.geotiff: attrs['vertical_crs'] holds the # EPSG integer code. The friendly string token is preserved under # attrs['vertical_datum'] so the human-readable name is not lost. # See GH issue #1570. 
out_attrs['vertical_crs'] = _VERTICAL_DATUM_EPSG.get(target_vertical_crs) out_attrs['vertical_datum'] = target_vertical_crs # Handle multi-band output (3D result from multi-band source) if result_data.ndim == 3: # Find the band dimension name from the source band_dims = [d for d in raster.dims if d not in (ydim, xdim)] band_dim = band_dims[0] if band_dims else 'band' out_dims = [ydim, xdim, band_dim] out_coords = {ydim: y_coords, xdim: x_coords} if band_dim in raster.coords: out_coords[band_dim] = raster.coords[band_dim] # Carry forward non-spatial coords (e.g. scalar 'time' coord). # Skip coords aligned to the rebuilt spatial dims because their # values do not apply to the new grid. for cname, cval in raster.coords.items(): if cname in (ydim, xdim, band_dim): continue if ydim in cval.dims or xdim in cval.dims: continue out_coords[cname] = cval else: out_dims = [ydim, xdim] out_coords = {ydim: y_coords, xdim: x_coords} # Carry forward non-spatial coords (e.g. scalar 'time' coord). # Skip coords aligned to the rebuilt spatial dims because their # values do not apply to the new grid. for cname, cval in raster.coords.items(): if cname in (ydim, xdim): continue if ydim in cval.dims or xdim in cval.dims: continue out_coords[cname] = cval result = xr.DataArray( result_data, dims=out_dims, coords=out_coords, name=name or raster.name, attrs=out_attrs, ) # Preserve the input's dim order so a ``(band, y, x)`` source produces a # ``(band, y, x)`` output (#2182). The internal pipeline always builds the # array as ``(y, x, band)`` for 3-D rasters; transpose back here. if result.ndim == 3 and set(_input_dims) == set(result.dims): if tuple(result.dims) != _input_dims: result = result.transpose(*_input_dims) return result
def _promoted_vertical_dtype(src_dtype): """Return the float dtype to use when applying a geoid shift. The geoid offset is fractional metres, so any integer input has to be promoted before we can add it. Returns ``None`` when the input is already a float dtype (including ``float64``) -- we leave its precision alone. Returns ``np.float32`` for any non-float input; that gives ~0.1 mm resolution in metres at altitudes up to ~10 km, which is well below the accuracy of the geoid grids themselves. """ src_dtype = np.dtype(src_dtype) if np.issubdtype(src_dtype, np.floating): return None return np.float32 def _apply_vertical_shift(data, y_coords, x_coords, src_vcrs, tgt_vcrs, nodata, tgt_crs_wkt=None): """Apply vertical datum shift to reprojected height values. The geoid undulation grid is in geographic (lon/lat) coordinates. If the output CRS is projected, coordinates are inverse-projected to geographic before the geoid lookup. Supported vertical CRS: - 'EGM96', 'EGM2008': orthometric heights (above geoid/MSL) - 'ellipsoidal': heights above WGS84 ellipsoid Backend handling ---------------- The geoid undulation lookup is CPU-only (Numba JIT). To stay correct across backends: - For ``cupy`` arrays, the input is brought to host, shifted, and moved back to GPU. - For ``dask`` arrays, the shift is wrapped in ``map_blocks`` so each chunk's slab is materialised, shifted, and returned in place. - For 3-D ``(y, x, band)`` results, the shift is applied per band because the geoid undulation depends only on horizontal position. Returns ------- (shifted, out_nodata) : tuple ``shifted`` is the height-shifted array, possibly promoted to a float dtype if the input was integer (see ``_promoted_vertical_dtype``). ``out_nodata`` is the nodata sentinel that matches the returned dtype -- ``NaN`` whenever a promotion happened so the caller can rely on NaN semantics in the float output. """ # Direction (sign convention) is the same regardless of backend. 
geoid_models = [] signs = [] if src_vcrs in ('EGM96', 'EGM2008') and tgt_vcrs == 'ellipsoidal': geoid_models.append(src_vcrs) signs.append(1.0) # H + N = h elif src_vcrs == 'ellipsoidal' and tgt_vcrs in ('EGM96', 'EGM2008'): geoid_models.append(tgt_vcrs) signs.append(-1.0) # h - N = H elif src_vcrs in ('EGM96', 'EGM2008') and tgt_vcrs in ('EGM96', 'EGM2008'): geoid_models.extend([src_vcrs, tgt_vcrs]) signs.extend([1.0, -1.0]) # H1 + N1 - N2 else: return data, nodata x_arr = np.asarray(x_coords, dtype=np.float64) y_arr = np.asarray(y_coords, dtype=np.float64) # Decide the output dtype once, here, so every backend agrees and # the caller can update attrs['nodata'] etc. with the post-shift # sentinel. promoted = _promoted_vertical_dtype(data.dtype) if promoted is None: out_dtype = np.dtype(data.dtype) out_nodata = nodata else: out_dtype = np.dtype(promoted) # NaN is the natural sentinel in a float output. The numpy path # below replaces any cells that matched the input integer # ``nodata`` with NaN before applying the shift. out_nodata = float('nan') # Dask backend: wrap the per-block computation. Each block sees its # own row slab of x/y coords, so we route through map_blocks and # delegate to the numpy path on every chunk. try: import dask.array as da is_dask = isinstance(data, da.Array) except ImportError: is_dask = False if is_dask: shifted = _apply_vertical_shift_dask( data, y_arr, x_arr, geoid_models, signs, nodata, tgt_crs_wkt, out_dtype=out_dtype, ) return shifted, out_nodata # CuPy backend: round-trip via host. The geoid lookup is small # relative to the reprojection itself, and the CPU JIT path is # already well-tested. 
try: import cupy as cp is_cupy = isinstance(data, cp.ndarray) except ImportError: is_cupy = False if is_cupy: host = cp.asnumpy(data) shifted = _apply_vertical_shift_numpy( host, y_arr, x_arr, geoid_models, signs, nodata, tgt_crs_wkt, out_dtype=out_dtype, ) return cp.asarray(shifted), out_nodata shifted = _apply_vertical_shift_numpy( np.asarray(data), y_arr, x_arr, geoid_models, signs, nodata, tgt_crs_wkt, out_dtype=out_dtype, ) return shifted, out_nodata def _apply_vertical_shift_numpy(data, y_arr, x_arr, geoid_models, signs, nodata, tgt_crs_wkt, out_dtype=None): """Apply geoid shift on a numpy array. Handles both 2-D ``(H, W)`` and 3-D ``(H, W, B)`` shapes by looping over band slices; the geoid undulation depends only on horizontal position, so each band sees the same N(y, x) correction. If ``out_dtype`` is given and differs from ``data.dtype`` (the integer-DEM case), the input is cast to that float dtype before the shift and the integer ``nodata`` sentinel is replaced with NaN in the promoted output. When ``out_dtype`` is ``None`` (or matches the input), the behaviour is in-place on a copy of the input -- same as before this signature was added. """ from ._vertical import _load_geoid, _interp_geoid_2d # Determine if we need inverse projection (output CRS is projected) need_inverse = False transformer = None if tgt_crs_wkt is not None: try: from ._crs_utils import _require_pyproj pyproj = _require_pyproj() tgt_crs = pyproj.CRS.from_wkt(tgt_crs_wkt) if not tgt_crs.is_geographic: need_inverse = True geo_crs = pyproj.CRS.from_epsg(4326) transformer = pyproj.Transformer.from_crs( tgt_crs, geo_crs, always_xy=True ) except Exception: pass assert data.ndim in (2, 3), f"expected 2-D or 3-D, got {data.ndim}-D" out_h, out_w = data.shape[:2] is_3d = data.ndim == 3 geoids = [_load_geoid(gm) for gm in geoid_models] # Process in row strips to bound memory in the numpy path; dask chunks # are usually smaller than one strip so this loop runs once per block. 
# If the caller asked for dtype promotion (integer DEM -> float), # build ``result`` in the new dtype and rewrite the input nodata # sentinel to NaN so the rest of the loop can treat NaN as the # missing-value marker uniformly. promoted = ( out_dtype is not None and np.dtype(out_dtype) != np.dtype(data.dtype) ) if promoted: result = data.astype(out_dtype, copy=True) # Map the source nodata (e.g. -32768 for int16) onto NaN before # shifting so downstream consumers can use NaN semantics. if isinstance(nodata, float) and np.isnan(nodata): pass # already NaN else: result[data == nodata] = np.nan is_nan_nodata = True else: result = data.copy() is_nan_nodata = np.isnan(nodata) if isinstance(nodata, float) else False strip = 128 for r0 in range(0, out_h, strip): r1 = min(r0 + strip, out_h) n_rows = r1 - r0 # Build strip coordinate grid xx_strip = np.tile(x_arr, n_rows).reshape(n_rows, out_w) yy_strip = np.repeat(y_arr[r0:r1], out_w).reshape(n_rows, out_w) # Inverse project if needed if need_inverse and transformer is not None: lon_s, lat_s = transformer.transform(xx_strip.ravel(), yy_strip.ravel()) xx_strip = np.asarray(lon_s, dtype=np.float64).reshape(n_rows, out_w) yy_strip = np.asarray(lat_s, dtype=np.float64).reshape(n_rows, out_w) # Guard against non-finite output coords (projection singularities, # antimeridian, polar regions). Hand NaN to the JIT batch so the # longitude wrap loop in _interp_geoid_point does not see inf and # spin forever. finite_coord = np.isfinite(xx_strip) & np.isfinite(yy_strip) if not finite_coord.all(): xx_strip = np.where(finite_coord, xx_strip, np.nan) yy_strip = np.where(finite_coord, yy_strip, np.nan) # Compute the total shift N_total for this strip once, regardless # of how many geoid models contribute. 
N_total = np.zeros((n_rows, out_w), dtype=np.float64) for (grid_data, g_left, g_top, g_rx, g_ry, g_h, g_w), sign in zip(geoids, signs): N_strip = np.empty((n_rows, out_w), dtype=np.float64) _interp_geoid_2d(xx_strip, yy_strip, N_strip, grid_data, g_left, g_top, g_rx, g_ry, g_h, g_w) N_total += sign * N_strip # Apply to each band slice. For 2-D this loop runs once. n_bands = data.shape[2] if is_3d else 1 for b in range(n_bands): if is_3d: strip_data = result[r0:r1, :, b] else: strip_data = result[r0:r1] if is_nan_nodata: is_valid = np.isfinite(strip_data) else: is_valid = strip_data != nodata is_valid = is_valid & finite_coord strip_data[is_valid] += N_total[is_valid] return result def _apply_vertical_shift_dask(data, y_arr, x_arr, geoid_models, signs, nodata, tgt_crs_wkt, out_dtype=None): """Dask-backed geoid shift via ``map_blocks``. Each block receives only its row/column slab of the input. The block function recomputes per-block y/x slices from ``block_info`` and delegates to the numpy path so all backends share one implementation. ``out_dtype`` is the (possibly promoted) dtype that ``_apply_vertical_shift`` has decided for the whole array, and is forwarded into every per-block call so the chunks return the same promoted dtype the dask graph metadata advertises. """ import dask.array as da if out_dtype is None: out_dtype = data.dtype # Note: blocks reaching this path are assumed to be numpy-backed. # dask-of-cupy is intentionally unreached today because # ``_reproject_dask(is_cupy=True)`` collapses to a numpy-backed dask # array upstream. If that ever changes, _block would need to detect # cupy chunks and host-bounce per chunk. 
def _block(block, block_info): info = block_info[0] (r0, r1), (c0, c1) = info['array-location'][:2] y_slab = y_arr[r0:r1] x_slab = x_arr[c0:c1] return _apply_vertical_shift_numpy( block, y_slab, x_slab, geoid_models, signs, nodata, tgt_crs_wkt, out_dtype=out_dtype, ) # ``meta`` hardcodes a numpy template to match the assumption above # that incoming chunks are numpy-backed. Revisit if dask-of-cupy is # ever plumbed through. return da.map_blocks(_block, data, dtype=out_dtype, meta=np.array((), dtype=out_dtype)) def _reproject_inmemory_numpy( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, x_desc=False, band_nodata=None, ): """Single-chunk numpy reproject.""" return _reproject_chunk_numpy( raster.values, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, source_x_desc=x_desc, band_nodata=band_nodata, ) def _reproject_inmemory_cupy( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, x_desc=False, band_nodata=None, ): """Single-chunk cupy reproject.""" return _reproject_chunk_cupy( raster.data, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, source_x_desc=x_desc, band_nodata=band_nodata, ) def _parse_max_memory(max_memory): """Parse max_memory parameter to bytes. 
Accepts int, '4GB', '512MB'.""" if max_memory is None: return 1024 * 1024 * 1024 # 1GB default if isinstance(max_memory, (int, float)): return int(max_memory) s = str(max_memory).strip().upper() for suffix, factor in [('TB', 1024**4), ('GB', 1024**3), ('MB', 1024**2), ('KB', 1024)]: if s.endswith(suffix): return int(float(s[:-len(suffix)]) * factor) return int(s) def _process_tile_batch(batch, source_data, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, resampling, nodata, precision, max_memory_bytes, tile_mem, x_desc=False, band_nodata=None): """Process a batch of tiles within a single worker. Uses ThreadPoolExecutor for intra-worker parallelism (Numba releases the GIL). Memory bounded by max_memory_bytes. Returns list of (row_offset, col_offset, tile_data) tuples. """ max_concurrent = max(1, max_memory_bytes // max(tile_mem, 1)) def _do_one(job): _, _, rchunk, cchunk, cb = job return _reproject_chunk_numpy( source_data, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, cb, (rchunk, cchunk), resampling, nodata, precision, source_x_desc=x_desc, band_nodata=band_nodata, ) results = [] if max_concurrent >= 2 and len(batch) > 1: import os from concurrent.futures import ThreadPoolExecutor n_threads = min(max_concurrent, len(batch), os.cpu_count() or 4) with ThreadPoolExecutor(max_workers=n_threads) as pool: for sub_start in range(0, len(batch), n_threads): sub = batch[sub_start:sub_start + n_threads] tiles = list(pool.map(_do_one, sub)) for job, tile in zip(sub, tiles): ro, co, rchunk, cchunk, _ = job results.append((ro, co, tile)) del tiles else: for job in batch: ro, co, rchunk, cchunk, _ = job tile = _do_one(job) results.append((ro, co, tile)) del tile return results def _reproject_streaming( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, tile_size, max_memory_bytes, x_desc=False, band_nodata=None, ): """Streaming reproject for datasets too large for dask's graph. Two modes: 1. 
**Local** (no dask.distributed): ThreadPoolExecutor within one process, bounded by max_memory. 2. **Distributed** (dask.distributed active): creates a dask.bag with one partition per worker, each partition processes its tile batch using threads. Graph size: O(n_workers), not O(n_tiles). Memory usage per worker: bounded by max_memory. """ if isinstance(tile_size, int): tile_size = (tile_size, tile_size) row_chunks, col_chunks = _compute_chunk_layout(out_shape, tile_size) tile_mem = tile_size[0] * tile_size[1] * 8 * 4 # ~4 arrays per tile # Build tile job list jobs = [] row_offset = 0 for rchunk in row_chunks: col_offset = 0 for cchunk in col_chunks: cb = _chunk_bounds( out_bounds, out_shape, row_offset, row_offset + rchunk, col_offset, col_offset + cchunk, ) jobs.append((row_offset, col_offset, rchunk, cchunk, cb)) col_offset += cchunk row_offset += rchunk # Check if dask.distributed is active _use_distributed = False try: from dask.distributed import get_client client = get_client() n_distributed_workers = len(client.scheduler_info()['workers']) if n_distributed_workers > 0: _use_distributed = True except (ImportError, ValueError): pass if _use_distributed and len(jobs) > n_distributed_workers: # Distributed: partition tiles across workers via dask.bag import dask.bag as db # Split jobs into N partitions (one per worker) n_parts = min(n_distributed_workers, len(jobs)) batch_size = math.ceil(len(jobs) / n_parts) batches = [jobs[i:i + batch_size] for i in range(0, len(jobs), batch_size)] # Create bag and map the batch processor bag = db.from_sequence(batches, npartitions=len(batches)) results_bag = bag.map( _process_tile_batch, source_data=raster.data, src_bounds=src_bounds, src_shape=src_shape, y_desc=y_desc, src_wkt=src_wkt, tgt_wkt=tgt_wkt, resampling=resampling, nodata=nodata, precision=precision, max_memory_bytes=max_memory_bytes, tile_mem=tile_mem, x_desc=x_desc, band_nodata=band_nodata, ) # Compute all partitions and assemble result result = np.full(out_shape, 
nodata, dtype=np.float64) for batch_results in results_bag.compute(): for ro, co, tile in batch_results: result[ro:ro + tile.shape[0], co:co + tile.shape[1]] = tile return result # Local: ThreadPoolExecutor within one process result = np.full(out_shape, nodata, dtype=np.float64) batch_results = _process_tile_batch( jobs, raster.data, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, resampling, nodata, precision, max_memory_bytes, tile_mem, x_desc=x_desc, band_nodata=band_nodata, ) for ro, co, tile in batch_results: result[ro:ro + tile.shape[0], co:co + tile.shape[1]] = tile return result def _reproject_dask_cupy( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, chunk_size, x_desc=False, band_nodata=None, ): """Dask+CuPy backend: process output chunks on GPU. Two modes based on available GPU memory: **Fast path** (output fits in GPU VRAM): pre-allocates the full output on GPU and fills it chunk-by-chunk. ~22x faster than the map_blocks path because CRS/transformer objects are created once and CUDA kernels run with minimal launch overhead. **Chunked fallback** (output exceeds GPU VRAM): delegates to ``_reproject_dask(is_cupy=True)`` which uses ``map_blocks`` and processes one chunk at a time with O(chunk_size) GPU memory. """ import cupy as cp from ._crs_utils import _crs_from_wkt src_crs = _crs_from_wkt(src_wkt) tgt_crs = _crs_from_wkt(tgt_wkt) # Use larger chunks for GPU to amortize kernel launch overhead gpu_chunk = chunk_size or 2048 if isinstance(gpu_chunk, int): gpu_chunk = (gpu_chunk, gpu_chunk) row_chunks, col_chunks = _compute_chunk_layout(out_shape, gpu_chunk) out_h, out_w = out_shape src_left, src_bottom, src_right, src_top = src_bounds src_h, src_w = src_shape src_res_x = (src_right - src_left) / src_w src_res_y = (src_top - src_bottom) / src_h # 3-D source: the fast path's inline loop assumes 2-D windows. 
# Delegate to the map_blocks path which handles 3-D via # _reproject_chunk_cupy's per-band loop (#2027). if raster.data.ndim == 3: return _reproject_dask( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, chunk_size or 2048, True, # is_cupy=True x_desc=x_desc, band_nodata=band_nodata, ) # Memory check: if the full output doesn't fit in GPU memory, # fall back to the map_blocks path which is O(chunk_size) memory. estimated = out_shape[0] * out_shape[1] * 8 # float64 try: free_gpu, _total = cp.cuda.Device().mem_info fits_in_gpu = estimated < 0.5 * free_gpu except (AttributeError, RuntimeError): fits_in_gpu = False if not fits_in_gpu: import warnings warnings.warn( f"Output ({estimated / 1e9:.1f} GB) exceeds GPU memory; " f"falling back to chunked map_blocks path.", stacklevel=3, ) return _reproject_dask( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, chunk_size or 2048, True, # is_cupy=True x_desc=x_desc, band_nodata=band_nodata, ) # Match the dask+numpy and chunked dask+cupy paths: integer sources # round-trip back to their original dtype after clamping; floats stay # float64. Without this, the eager dask+cupy fast path silently # promoted int16/uint8/etc. inputs to float64 while the other # backends preserved the source dtype (#2505). src_dtype = np.dtype(raster.dtype) if np.issubdtype(src_dtype, np.integer): out_dtype = src_dtype else: out_dtype = np.dtype(np.float64) result = cp.full(out_shape, nodata, dtype=out_dtype) row_offset = 0 for i, rchunk in enumerate(row_chunks): col_offset = 0 for j, cchunk in enumerate(col_chunks): cb = _chunk_bounds( out_bounds, out_shape, row_offset, row_offset + rchunk, col_offset, col_offset + cchunk, ) chunk_shape = (rchunk, cchunk) # CUDA coordinate transform (reuses cached CRS objects). # precision == 0 forces the exact pyproj path, so skip CUDA. 
cuda_coords = None if precision != 0: try: from ._projections_cuda import try_cuda_transform cuda_coords = try_cuda_transform( src_crs, tgt_crs, cb, chunk_shape, ) except (ImportError, ModuleNotFoundError): cuda_coords = None if cuda_coords is not None: src_y, src_x = cuda_coords if x_desc: src_col_px = (src_right - src_x) / src_res_x - 0.5 else: src_col_px = (src_x - src_left) / src_res_x - 0.5 if y_desc: src_row_px = (src_top - src_y) / src_res_y - 0.5 else: src_row_px = (src_y - src_bottom) / src_res_y - 0.5 # Batch the four reductions into a single device-to-host # transfer instead of four separate synchronous .get() calls. mins_maxes = cp.stack([ cp.nanmin(src_row_px), cp.nanmax(src_row_px), cp.nanmin(src_col_px), cp.nanmax(src_col_px), ]) r_min_val, r_max_val, c_min_val, c_max_val = ( float(v) for v in mins_maxes.get() ) if not (np.isfinite(r_min_val) and np.isfinite(r_max_val) and np.isfinite(c_min_val) and np.isfinite(c_max_val)): col_offset += cchunk continue r_min = int(np.floor(r_min_val)) - 2 r_max = int(np.ceil(r_max_val)) + 3 c_min = int(np.floor(c_min_val)) - 2 c_max = int(np.ceil(c_max_val)) + 3 else: # CPU fallback for this chunk from ._crs_utils import _require_pyproj pyproj = _require_pyproj() transformer = pyproj.Transformer.from_crs( tgt_crs, src_crs, always_xy=True ) src_y, src_x = _transform_coords( transformer, cb, chunk_shape, precision, src_crs=src_crs, tgt_crs=tgt_crs, ) if x_desc: src_col_px = (src_right - src_x) / src_res_x - 0.5 else: src_col_px = (src_x - src_left) / src_res_x - 0.5 if y_desc: src_row_px = (src_top - src_y) / src_res_y - 0.5 else: src_row_px = (src_y - src_bottom) / src_res_y - 0.5 r_min = np.nanmin(src_row_px) r_max = np.nanmax(src_row_px) c_min = np.nanmin(src_col_px) c_max = np.nanmax(src_col_px) if not np.isfinite(r_min) or not np.isfinite(r_max): col_offset += cchunk continue if not np.isfinite(c_min) or not np.isfinite(c_max): col_offset += cchunk continue r_min = int(np.floor(r_min)) - 2 r_max = 
int(np.ceil(r_max)) + 3 c_min = int(np.floor(c_min)) - 2 c_max = int(np.ceil(c_max)) + 3 # Check overlap if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0: col_offset += cchunk continue r_min_clip = max(0, r_min) r_max_clip = min(src_h, r_max) c_min_clip = max(0, c_min) c_max_clip = min(src_w, c_max) # Guard: cap source window to prevent GPU OOM (matches numpy path) _MAX_WINDOW_PIXELS = 64 * 1024 * 1024 # 64 Mpix win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip) if win_pixels > _MAX_WINDOW_PIXELS: col_offset += cchunk continue # Fetch only the needed source window from dask window = raster.data[r_min_clip:r_max_clip, c_min_clip:c_max_clip] if hasattr(window, 'compute'): window = window.compute() if not isinstance(window, cp.ndarray): window = cp.asarray(window) window = window.astype(cp.float64) if not np.isnan(nodata): window[window == nodata] = cp.nan local_row = src_row_px - r_min_clip local_col = src_col_px - c_min_clip # Always use the native CUDA kernels (numpy parity). The # cupyx.scipy.ndimage.map_coordinates fallback diverged from # numpy at chunk edges and for cubic resampling (#2620). # local_row/local_col may be numpy here (pyproj fallback); # _resample_cupy_native transfers them to the GPU. chunk_data = _resample_cupy_native( window, local_row, local_col, resampling=resampling, nodata=nodata, ) # Clamp + cast back for integer source dtypes so this fast # path returns the same dtype as the other backends (#2505). # Matches the per-chunk cast in _reproject_chunk_cupy. if np.issubdtype(out_dtype, np.integer): info = np.iinfo(out_dtype) chunk_data = cp.clip( cp.round(chunk_data), info.min, info.max, ).astype(out_dtype) result[row_offset:row_offset + rchunk, col_offset:col_offset + cchunk] = chunk_data col_offset += cchunk row_offset += rchunk return result def _finite_pair_bbox(tx, ty): """Bounding box of (tx, ty) pairs where both coordinates are finite. 
The x and y coordinates must be filtered together: a transform can send some probe points to NaN/inf, and dropping finite x and finite y independently would mix coordinates from different points into one box. Returns ``(left, bottom, right, top)`` or ``None`` when no pair is finite in both coordinates. """ tx = np.asarray(tx, dtype=np.float64) ty = np.asarray(ty, dtype=np.float64) mask = np.isfinite(tx) & np.isfinite(ty) if not mask.any(): return None tx = tx[mask] ty = ty[mask] return (float(tx.min()), float(ty.min()), float(tx.max()), float(ty.max())) def _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt): """Compute approximate bounding box of source raster in target CRS.""" try: from ._crs_utils import _crs_from_wkt, _resolve_crs try: src_crs = _crs_from_wkt(src_wkt) except Exception: src_crs = _resolve_crs(src_wkt) try: tgt_crs = _crs_from_wkt(tgt_wkt) except Exception: tgt_crs = _resolve_crs(tgt_wkt) except Exception: return None sl, sb, sr, st = src_bounds mx = (sl + sr) / 2 my = (sb + st) / 2 xs = np.array([sl, mx, sr, sl, mx, sr, sl, mx, sr, sl, sr, mx]) ys = np.array([sb, sb, sb, my, my, my, st, st, st, mx, mx, sb]) try: from ._projections import transform_points result = transform_points(src_crs, tgt_crs, xs, ys) if result is not None: tx, ty = result return _finite_pair_bbox(tx, ty) except (ImportError, ModuleNotFoundError): pass try: from ._crs_utils import _require_pyproj pyproj = _require_pyproj() transformer = pyproj.Transformer.from_crs(src_crs, tgt_crs, always_xy=True) tx, ty = transformer.transform(xs.tolist(), ys.tolist()) return _finite_pair_bbox(tx, ty) except Exception: return None def _bounds_overlap(a, b): """Return True if bounding boxes *a* and *b* overlap.""" return a[0] < b[2] and a[2] > b[0] and a[1] < b[3] and a[3] > b[1] def _reproject_block_adapter( block, block_info, source_data, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, is_cupy, src_footprint_tgt, n_bands=None, 
x_desc=False, band_nodata=None, ): """``map_blocks`` adapter for reprojection. Derives chunk bounds from *block_info* and delegates to the per-chunk worker. For 3-D sources the template carries the band axis, so each block is ``(rh, rw, n_bands)``. The adapter strips the trailing band axis when computing 2-D chunk bounds and the per-chunk worker returns a 3-D result that fits the template (#2027). """ info = block_info[0] # 3-D template -> array-location is 3 entries; spatial dims are the # first two. Band dim spans the full axis (single chunk). spatial_loc = info['array-location'][:2] (row_start, row_end), (col_start, col_end) = spatial_loc chunk_shape = (row_end - row_start, col_end - col_start) cb = _chunk_bounds(out_bounds, out_shape, row_start, row_end, col_start, col_end) is_3d = n_bands is not None # Skip chunks that don't overlap the source footprint if src_footprint_tgt is not None and not _bounds_overlap(cb, src_footprint_tgt): if is_3d: empty_shape = (*chunk_shape, n_bands) if is_cupy: import cupy as cp return cp.full(empty_shape, nodata, dtype=cp.float64) return np.full(empty_shape, nodata, dtype=np.float64) return np.full(chunk_shape, nodata, dtype=np.float64) chunk_fn = _reproject_chunk_cupy if is_cupy else _reproject_chunk_numpy return chunk_fn( source_data, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, cb, chunk_shape, resampling, nodata, precision, source_x_desc=x_desc, band_nodata=band_nodata, ) def _reproject_dask( raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt, out_bounds, out_shape, resampling, nodata, precision, chunk_size, is_cupy, x_desc=False, band_nodata=None, ): """Dask+NumPy backend: ``map_blocks`` over a template array. Uses a single ``blockwise`` layer in the HighLevelGraph instead of O(N) ``dask.delayed`` nodes, keeping graph metadata O(1). The source dask array is bound to the adapter via ``functools.partial`` rather than passed as a ``map_blocks`` kwarg. 
This prevents dask from adding the full source as a dependency of every output block (which would cause a MemoryError on distributed schedulers when the source exceeds the worker memory limit). For 3-D sources with shape ``(H, W, n_bands)`` the template is built as ``(out_H, out_W, n_bands)`` so the lazy metadata matches the actual chunk output shape (#2027). Without this, the lazy DataArray advertised 2-D shape while the underlying chunks produced 3-D arrays, causing a ``ValueError: replacement data has shape ...`` crash on ``.compute()``. """ import functools import dask.array as da row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size) # Precompute source footprint in target CRS for empty-chunk skipping src_footprint_tgt = _source_footprint_in_target( src_bounds, src_wkt, tgt_wkt ) # Detect 3-D source: chunks will return (rh, rw, n_bands) so the # template must carry the band axis through to the lazy DataArray. is_3d = raster.data.ndim == 3 n_bands = raster.data.shape[2] if is_3d else None # Bind the source dask array and all scalar params via partial so # map_blocks doesn't detect them as dask Array kwargs (which would # add the full source as a dependency of every output block). bound_adapter = functools.partial( _reproject_block_adapter, source_data=raster.data, src_bounds=src_bounds, src_shape=src_shape, y_desc=y_desc, src_wkt=src_wkt, tgt_wkt=tgt_wkt, out_bounds=out_bounds, out_shape=out_shape, resampling=resampling, nodata=nodata, precision=precision, is_cupy=is_cupy, src_footprint_tgt=src_footprint_tgt, n_bands=n_bands, x_desc=x_desc, band_nodata=band_nodata, ) # Pick the template dtype to match the eager path: integer sources # round-trip back to their original dtype after clamping; floats stay # float64. Without this, dask claims float64 meta but the chunks # actually return the integer dtype, producing inconsistent output. 
src_dtype = np.dtype(raster.dtype) if np.issubdtype(src_dtype, np.integer): out_dtype = src_dtype else: out_dtype = np.dtype(np.float64) if is_3d: # Band axis is one chunk in the template regardless of how the # source dask array is chunked along its band axis. The per-block # worker reads all bands together (via a 2-D y/x slice that # rejoins band chunks on compute) and emits the full band stack # for its (rh, rw) tile, so multi-chunk output along bands would # never get filled. template = da.empty( (*out_shape, n_bands), dtype=out_dtype, chunks=(row_chunks, col_chunks, (n_bands,)), ) else: template = da.empty( out_shape, dtype=out_dtype, chunks=(row_chunks, col_chunks) ) return da.map_blocks( bound_adapter, template, dtype=out_dtype, meta=np.array((), dtype=out_dtype), ) # --------------------------------------------------------------------------- # merge() # ---------------------------------------------------------------------------
def merge(
    rasters,
    *,
    target_crs=None,
    resolution=None,
    bounds=None,
    resampling='bilinear',
    nodata=None,
    strategy='first',
    chunk_size=None,
    transform_precision=16,
    name=None,
    bounds_policy="auto",
):
    """Merge multiple rasters into a single mosaic.

    Each input is reprojected to the target CRS (if needed) and placed
    into a unified output grid. Overlapping regions are resolved using
    the selected *strategy*.

    Parameters
    ----------
    rasters : list of xr.DataArray
        Input rasters to merge. Each must be 2-D, carry CRS metadata, and
        have strictly monotonic, regularly spaced spatial coords.
    target_crs : optional
        Target CRS. Defaults to the CRS of the first raster.
    resolution : float or (float, float) or None
        Output resolution in target CRS units.
    bounds : (left, bottom, right, top) or None
        Explicit output extent. When None, the union of every input's
        extent projected into the target CRS is used.
    resampling : str
        Interpolation method: 'nearest', 'bilinear', 'cubic'.
    nodata : float or None
        Nodata value for the output. When None it is detected from the
        first raster.
    strategy : str
        Merge strategy: 'first', 'last', 'mean', 'max', 'min'.
    chunk_size : int or (int, int) or None
        Output chunk size for the dask path. That path is used when any
        input is dask-backed or when the output is auto-promoted because
        it is too large to merge in memory. The in-memory path ignores
        this value. When None, the dask merge helper chooses its own
        default.
    transform_precision : int
        Control-grid subdivisions for the coordinate transform (default
        16). Higher values increase accuracy at the cost of more pyproj
        calls. Set to 0 for exact per-pixel transforms matching
        GDAL/rasterio.
    name : str or None
        Name for the output DataArray. Falls back to the first raster's
        name, then to ``'merged'``.
    bounds_policy : {"auto", "raw", "clamp", "percentile"}, default "auto"
        How to derive the unified output extent from the input rasters
        when ``bounds`` is not supplied. See :func:`reproject` for the
        full description of each option. Ignored when ``bounds`` is given.

    Returns
    -------
    xr.DataArray
        The output y coordinate is always emitted in descending order
        (top-down, north-up) regardless of the input direction. This
        matches the standard raster convention and the output of common
        GIS libraries.

        Non-spatial coords from the first raster (such as a scalar
        ``time`` coord) are carried through to the output. Coords aligned
        to the spatial dims are dropped because their values do not apply
        to the merged grid. Attrs are copied from the first raster, with
        ``crs``, ``nodata``, ``res`` and ``transform`` rewritten for the
        output grid.

    Raises
    ------
    ValueError
        If *rasters* is empty, if no target CRS can be resolved, or if any
        input lacks CRS metadata. The shared argument validators may also
        raise for invalid inputs.

    Notes
    -----
    There is no streaming in-memory path. For very large output mosaics,
    pass dask-backed inputs (or rely on the automatic promotion to the
    dask path) so that each output chunk is computed independently.

    When ``bounds_policy`` alters the projected extent of one or more
    inputs, a single summary ``UserWarning`` is emitted instead of one
    warning per input.

    Examples
    --------
    >>> import xarray as xr
    >>> import numpy as np
    >>> from xrspatial.reproject import merge
    >>> tile_a = xr.DataArray(
    ...     np.full((32, 32), 1.0), dims=['y', 'x'],
    ...     coords={'y': np.linspace(50, 45, 32),
    ...             'x': np.linspace(-5, 0, 32)},
    ...     attrs={'crs': 'EPSG:4326'},
    ... )
    >>> tile_b = xr.DataArray(
    ...     np.full((32, 32), 2.0), dims=['y', 'x'],
    ...     coords={'y': np.linspace(50, 45, 32),
    ...             'x': np.linspace(0, 5, 32)},
    ...     attrs={'crs': 'EPSG:4326'},
    ... )
    >>> mosaic = merge([tile_a, tile_b], resolution=0.25)
    >>> mosaic.shape[1] >= 32
    True
    """
    if not rasters:
        raise ValueError("merge(): rasters list must not be empty")

    for i, r in enumerate(rasters):
        # merge() only supports 2-D rasters: the merge strategies, same-CRS
        # placement, and output DataArray construction all assume (y, x).
        # The 3-D (y, x, band) path was never implemented end-to-end and
        # crashed at DataArray construction with a dims-vs-shape mismatch
        # (#2027). Reject 3-D up front so callers get a clear error.
        _validate_raster(r, func_name='merge', name=f'rasters[{i}]', ndim=(2,))
        # Each input must have strictly monotonic, regularly spaced spatial
        # coords. _source_bounds() infers pixel size from only the first
        # two samples, so irregular inputs would otherwise silently produce
        # wrong georeferencing (#2184).
        _validate_source_coords(r, f'merge[rasters[{i}]]')

    _validate_grid_params(
        resolution=resolution, bounds=bounds, width=None, height=None,
        transform_precision=transform_precision, func_name='merge',
    )
    _validate_resampling(resampling)
    _validate_strategy(strategy)
    from ._grid import _validate_bounds_policy
    _validate_bounds_policy(bounds_policy, func_name='merge')

    # Resolve target CRS, falling back to the first raster's CRS.
    tgt_crs = _resolve_crs(target_crs)
    if tgt_crs is None:
        tgt_crs = _detect_source_crs(rasters[0])
    if tgt_crs is None:
        raise ValueError(
            "Could not detect target CRS. Pass target_crs explicitly."
        )

    # Detect output nodata (the sentinel the user asked for). The merge
    # output is always float64, so NaN is fine as the output sentinel
    # when the user didn't supply one -- the integer-cast hazard from
    # #2185 only applies to the per-input ``r_nd`` values that flow into
    # the per-raster reproject worker.
    nd = nodata if nodata is not None else _detect_nodata(rasters[0], nodata)

    # Gather source info for each raster.
    raster_infos = []
    for r in rasters:
        src_crs = _detect_source_crs(r)
        if src_crs is None:
            raise ValueError(
                f"Could not detect CRS for raster '{r.name}'. "
                "Ensure all rasters have CRS metadata."
            )
        r_ydim, r_xdim = _find_spatial_dims(r)
        # Per-raster input nodata sentinel. Detected independently of the
        # user-supplied output nodata so that mixed-sentinel inputs are
        # canonicalized correctly during merge. Pass the raster dtype so
        # integer sources get an integer-compatible sentinel rather than
        # NaN -- the same fix as the reproject() path for #2185.
        raster_infos.append({
            'raster': r,
            'src_crs': src_crs,
            'src_bounds': _source_bounds(r),
            'src_shape': (r.sizes[r_ydim], r.sizes[r_xdim]),
            'y_desc': _is_y_descending(r),
            'x_desc': _is_x_descending(r),
            'src_wkt': src_crs.to_wkt(),
            'raster_nodata': _detect_nodata(r, None, dtype=r.dtype),
        })

    # Compute unified output grid.
    if bounds is None:
        # Union of all raster bounds in target CRS.
        # Capture bounds_policy warnings during the per-input gather and
        # emit a single deduplicated message afterwards. Otherwise a
        # mosaic of N near-antimeridian rasters yields N identical
        # warnings, which is noise the caller cannot act on individually.
        import warnings as _warnings

        def _is_policy_warning(w):
            # One predicate for both the summary and the re-emit pass, so
            # every captured warning is either summarized or re-emitted
            # and none is silently dropped.
            return (issubclass(w.category, UserWarning)
                    and 'bounds_policy' in str(w.message))

        all_bounds = []
        n_triggered = 0  # rasters (not warnings) that hit bounds_policy
        with _warnings.catch_warnings(record=True) as _caught:
            _warnings.simplefilter('always', UserWarning)
            for info in raster_infos:
                n_seen = len(_caught)
                grid = _compute_output_grid(
                    info['src_bounds'], info['src_shape'],
                    info['src_crs'], tgt_crs,
                    resolution=resolution,
                    bounds_policy=bounds_policy,
                )
                all_bounds.append(grid['bounds'])
                if any(_is_policy_warning(w) for w in _caught[n_seen:]):
                    n_triggered += 1

        _policy_msgs = [str(w.message) for w in _caught
                        if _is_policy_warning(w)]
        if _policy_msgs:
            _unique = list(dict.fromkeys(_policy_msgs))
            _summary = (
                f"merge: bounds_policy={bounds_policy!r} altered the "
                f"projected extent for {n_triggered} input "
                f"raster(s); {len(_unique)} unique trigger(s). "
                f"First trigger: {_unique[0]}"
            )
            _warnings.warn(_summary, UserWarning, stacklevel=2)
        # Re-emit every captured warning that was not folded into the
        # summary above.
        for _w in _caught:
            if not _is_policy_warning(_w):
                _warnings.warn_explicit(
                    _w.message, _w.category, _w.filename, _w.lineno,
                )

        merged_bounds = (
            min(b[0] for b in all_bounds),
            min(b[1] for b in all_bounds),
            max(b[2] for b in all_bounds),
            max(b[3] for b in all_bounds),
        )
    else:
        merged_bounds = bounds

    # Use first raster's info for resolution estimation if needed.
    info0 = raster_infos[0]
    grid = _compute_output_grid(
        info0['src_bounds'], info0['src_shape'],
        info0['src_crs'], tgt_crs,
        resolution=resolution, bounds=merged_bounds,
        bounds_policy=bounds_policy,
    )
    out_bounds = grid['bounds']
    out_shape = grid['shape']
    tgt_wkt = tgt_crs.to_wkt()

    # Detect if any input is dask, or if total size exceeds memory threshold.
    from ..utils import has_dask_array
    any_dask = False
    if has_dask_array():
        import dask.array as _da
        any_dask = any(isinstance(r.data, _da.Array) for r in rasters)

    # Auto-promote to the dask path if the output would be too large for
    # an in-memory merge (float64 per input tile).
    if not any_dask:
        out_nbytes = out_shape[0] * out_shape[1] * 8 * len(rasters)
        _OOM_THRESHOLD = 512 * 1024 * 1024
        if out_nbytes > _OOM_THRESHOLD:
            any_dask = True

    if any_dask:
        result_data = _merge_dask(
            raster_infos, tgt_wkt, out_bounds, out_shape,
            resampling, nd, strategy, chunk_size,
            transform_precision,
        )
    else:
        result_data = _merge_inmemory(
            raster_infos, tgt_wkt, out_bounds, out_shape,
            resampling, nd, strategy, transform_precision,
        )

    y_coords, x_coords = _make_output_coords(out_bounds, out_shape)
    ydim, xdim = _find_spatial_dims(rasters[0])
    out_coords = {ydim: y_coords, xdim: x_coords}
    # Carry forward non-spatial coords from the first raster (e.g. scalar
    # 'time' coord). Skip coords aligned to the rebuilt spatial dims
    # because their values do not apply to the new grid.
    for cname, cval in rasters[0].coords.items():
        if cname in (ydim, xdim):
            continue
        if ydim in cval.dims or xdim in cval.dims:
            continue
        out_coords[cname] = cval

    # Carry the first raster's attrs forward (matches the default
    # strategy='first'). Drop the duplicate `crs_wkt` and re-emit
    # `transform` and `res` for the new output grid (rasterio 6-tuple).
    out_left, _, _, out_top = out_bounds
    out_res_x = grid['res_x']
    out_res_y = grid['res_y']
    out_attrs = {**rasters[0].attrs}
    out_attrs.pop('crs_wkt', None)
    out_attrs['crs'] = tgt_wkt
    out_attrs['nodata'] = nd
    out_attrs['res'] = (out_res_x, out_res_y)
    out_attrs['transform'] = (
        out_res_x, 0.0, out_left,
        0.0, -out_res_y, out_top,
    )
    # If the first raster used `_FillValue`, propagate the resolved
    # nodata to both keys for consistent round-trip. Leave absent
    # otherwise.
    if '_FillValue' in rasters[0].attrs:
        out_attrs['_FillValue'] = nd
    # Same treatment for `nodatavals` (rasterio convention), keeping the
    # original entry count (at least one).
    if 'nodatavals' in rasters[0].attrs:
        old_nv = rasters[0].attrs['nodatavals']
        try:
            n_entries = max(1, len(old_nv))
        except TypeError:
            n_entries = 1
        out_attrs['nodatavals'] = tuple(nd for _ in range(n_entries))

    return xr.DataArray(
        result_data,
        dims=[ydim, xdim],
        coords=out_coords,
        name=name or rasters[0].name or 'merged',
        attrs=out_attrs,
    )
def _place_same_crs(src_data, src_bounds, src_shape, y_desc, out_bounds, out_shape, nodata, x_desc=False): """Place a same-CRS tile into the output grid by coordinate alignment. No reprojection needed -- just index the output rows/columns that overlap with the source tile and copy the data. The output grid is always north-up with ascending x (row 0 is the top of ``out_bounds``, column 0 is the left edge). When ``y_desc`` is ``False`` the source is y-ascending (row 0 is the bottom of ``src_bounds``); in that case the source window is flipped along y before being written so the placed data has the same north-up orientation as the rest of the output. Without this, ``merge([r])`` of a y-ascending raster silently differs from ``reproject(r, target_crs=r.crs)`` (#2186). ``x_desc`` mirrors the same logic on the horizontal axis: when True, ``src_data`` is laid out with column 0 at the maximum x, so the source window is reversed along its column axis before placement (#2183). """ out_h, out_w = out_shape src_h, src_w = src_shape o_left, o_bottom, o_right, o_top = out_bounds s_left, s_bottom, s_right, s_top = src_bounds o_res_x = (o_right - o_left) / out_w o_res_y = (o_top - o_bottom) / out_h s_res_x = (s_right - s_left) / src_w s_res_y = (s_top - s_bottom) / src_h # Output pixel range that this tile covers. The output is always # north-up so ``row_start`` is measured from ``o_top`` downward. 
col_start = int(round((s_left - o_left) / o_res_x)) col_end = int(round((s_right - o_left) / o_res_x)) row_start = int(round((o_top - s_top) / o_res_y)) row_end = int(round((o_top - s_bottom) / o_res_y)) # Clip to output bounds col_start_clip = max(0, col_start) col_end_clip = min(out_w, col_end) row_start_clip = max(0, row_start) row_end_clip = min(out_h, row_end) if col_start_clip >= col_end_clip or row_start_clip >= row_end_clip: return np.full(out_shape, nodata, dtype=np.float64) # Resolutions may differ slightly; if close enough, do direct copy res_ratio_x = s_res_x / o_res_x res_ratio_y = s_res_y / o_res_y if abs(res_ratio_x - 1.0) > 0.01 or abs(res_ratio_y - 1.0) > 0.01: return None # resolutions too different, fall back to reproject # Source column offset (columns always run left-to-right). src_col_start = col_start_clip - col_start src_c_end = min(src_col_start + (col_end_clip - col_start_clip), src_w) actual_cols = src_c_end - src_col_start out_data = np.full(out_shape, nodata, dtype=np.float64) if actual_cols <= 0: return out_data # Resolve the source column slice. When ``x_desc`` is True the # source has column 0 at max x, so we read a mirrored slice and # reverse it; otherwise we read the natural left-to-right slice. # ``src_col_start`` / ``src_c_end`` were computed above in # ascending-x source space and stay valid for both branches. if x_desc: c_lo = src_w - src_c_end c_hi = src_w - src_col_start else: c_lo = src_col_start c_hi = src_c_end if y_desc: # Source is north-up: rows go top-to-bottom, same as output. # Output row R corresponds to source row ``R - row_start``. 
src_row_start = row_start_clip - row_start src_r_end = min(src_row_start + (row_end_clip - row_start_clip), src_h) actual_rows = src_r_end - src_row_start if actual_rows <= 0: return out_data src_window = np.asarray( src_data[src_row_start:src_r_end, c_lo:c_hi], dtype=np.float64, ) if x_desc: src_window = src_window[:, ::-1] out_row_start = row_start_clip else: # Source is south-up: source row 0 sits at the bottom of # ``src_bounds``. Output row R corresponds to source row # ``row_end - 1 - R`` (with R and row_end both measured from # ``o_top`` downward, so larger R means lower latitude and # smaller source row index). Concretely, if src_h=4 and # row_end=4, output row 0 maps to source row 3, output row 3 # maps to source row 0 -- a vertical flip. Read a contiguous # ascending slice of source rows and reverse it so the # placed window comes out north-up. src_lo = max(0, row_end - row_end_clip) src_hi = min(src_h, row_end - row_start_clip) actual_rows = src_hi - src_lo if actual_rows <= 0: return out_data src_window = np.asarray( src_data[src_lo:src_hi, c_lo:c_hi], dtype=np.float64, )[::-1, :] if x_desc: src_window = src_window[:, ::-1] # The reversed slice's first row corresponds to source row # ``src_hi - 1``, which maps to output row ``row_end - src_hi``. # When the source is not clipped at the top, this equals # ``row_start_clip``; clipping at the source top shifts the # placement downward by the clipped row count. out_row_start = row_end - src_hi out_data[out_row_start:out_row_start + actual_rows, col_start_clip:col_start_clip + actual_cols] = src_window return out_data def _merge_inmemory( raster_infos, tgt_wkt, out_bounds, out_shape, resampling, nodata, strategy, transform_precision, ): """In-memory merge using numpy. Detects same-CRS tiles and uses fast direct placement instead of reprojection. 
Each raster is reprojected using its own nodata sentinel and then canonicalized to NaN before the strategy merge so that mixed-sentinel inputs do not leak invalid pixels into the mosaic. After merging, the NaN canonical sentinel is converted back to the user-requested output ``nodata``. """ from ._crs_utils import _crs_from_wkt tgt_crs = _crs_from_wkt(tgt_wkt) arrays = [] for info in raster_infos: r_nd = info.get('raster_nodata', float('nan')) # Check if source CRS matches target (no reprojection needed) placed = None if info['src_crs'] == tgt_crs: placed = _place_same_crs( info['raster'].values, info['src_bounds'], info['src_shape'], info['y_desc'], out_bounds, out_shape, r_nd, x_desc=info.get('x_desc', False), ) if placed is not None: arr = placed else: arr = _reproject_chunk_numpy( info['raster'].values, info['src_bounds'], info['src_shape'], info['y_desc'], info['src_wkt'], tgt_wkt, out_bounds, out_shape, resampling, r_nd, transform_precision, source_x_desc=info.get('x_desc', False), ) # Canonicalize this raster's sentinel to NaN before the merge so # rasters with different sentinels merge correctly. if not np.isnan(r_nd): arr = np.asarray(arr, dtype=np.float64) arr = np.where(arr == r_nd, np.nan, arr) arrays.append(arr) merged = _merge_arrays_numpy(arrays, float('nan'), strategy) if not np.isnan(nodata): merged = np.where(np.isnan(merged), nodata, merged) return merged def _merge_block_adapter( block, block_info, raster_data_list, src_bounds_list, src_shape_list, y_desc_list, src_wkt_list, tgt_wkt, out_bounds, out_shape, resampling, nodata, strategy, precision, src_footprints_tgt, raster_nodata_list, same_crs_list, x_desc_list=None, ): """``map_blocks`` adapter for merge. Each raster is reprojected with its own input nodata sentinel and canonicalized to NaN before the strategy merge. The final result is converted back to the user-requested output ``nodata``. 
""" info = block_info[0] (row_start, row_end), (col_start, col_end) = info['array-location'] chunk_shape = (row_end - row_start, col_end - col_start) cb = _chunk_bounds(out_bounds, out_shape, row_start, row_end, col_start, col_end) # Only reproject rasters whose footprint overlaps this chunk arrays = [] for i in range(len(raster_data_list)): if (src_footprints_tgt[i] is not None and not _bounds_overlap(cb, src_footprints_tgt[i])): continue r_nd = raster_nodata_list[i] placed = None xd_i = x_desc_list[i] if x_desc_list is not None else False if same_crs_list[i]: # Same-CRS path: direct pixel placement (no resampling). # Pass the dask array straight through -- _place_same_crs # slices before np.asarray(), so np.asarray on the slice # materializes only the source window for this output chunk. # An eager .compute() here would materialize the full source # per output chunk, amplifying driver-side data flow by O(N). placed = _place_same_crs( raster_data_list[i], src_bounds_list[i], src_shape_list[i], y_desc_list[i], cb, chunk_shape, r_nd, x_desc=xd_i, ) if placed is not None: arr = placed else: arr = _reproject_chunk_numpy( raster_data_list[i], src_bounds_list[i], src_shape_list[i], y_desc_list[i], src_wkt_list[i], tgt_wkt, cb, chunk_shape, resampling, r_nd, precision, source_x_desc=xd_i, ) if not np.isnan(r_nd): arr = np.asarray(arr, dtype=np.float64) arr = np.where(arr == r_nd, np.nan, arr) arrays.append(arr) if not arrays: return np.full(chunk_shape, nodata, dtype=np.float64) merged = _merge_arrays_numpy(arrays, float('nan'), strategy) if not np.isnan(nodata): merged = np.where(np.isnan(merged), nodata, merged) return merged def _merge_dask( raster_infos, tgt_wkt, out_bounds, out_shape, resampling, nodata, strategy, chunk_size, transform_precision, ): """Dask merge backend using ``map_blocks``.""" import functools import dask.array as da row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size) # Prepare lists for the worker data_list = 
[info['raster'].data for info in raster_infos] bounds_list = [info['src_bounds'] for info in raster_infos] shape_list = [info['src_shape'] for info in raster_infos] ydesc_list = [info['y_desc'] for info in raster_infos] xdesc_list = [info.get('x_desc', False) for info in raster_infos] wkt_list = [info['src_wkt'] for info in raster_infos] rnodata_list = [ info.get('raster_nodata', float('nan')) for info in raster_infos ] # Precompute source footprints in target CRS footprints = [ _source_footprint_in_target(bounds_list[i], wkt_list[i], tgt_wkt) for i in range(len(raster_infos)) ] # Precompute CRS-equality flags so per-block adapters can shortcut to # direct pixel placement (matches the eager _merge_inmemory path). from ._crs_utils import _crs_from_wkt tgt_crs = _crs_from_wkt(tgt_wkt) same_crs_list = [ bool(_crs_from_wkt(wkt_list[i]) == tgt_crs) for i in range(len(raster_infos)) ] # Bind via partial to prevent map_blocks from adding dask arrays # in data_list as whole-array dependencies. bound_adapter = functools.partial( _merge_block_adapter, raster_data_list=data_list, src_bounds_list=bounds_list, src_shape_list=shape_list, y_desc_list=ydesc_list, src_wkt_list=wkt_list, tgt_wkt=tgt_wkt, out_bounds=out_bounds, out_shape=out_shape, resampling=resampling, nodata=nodata, strategy=strategy, precision=transform_precision, src_footprints_tgt=footprints, raster_nodata_list=rnodata_list, same_crs_list=same_crs_list, x_desc_list=xdesc_list, ) template = da.empty( out_shape, dtype=np.float64, chunks=(row_chunks, col_chunks) ) return da.map_blocks( bound_adapter, template, dtype=np.float64, meta=np.array((), dtype=np.float64), )