# Source code for xrspatial.reproject

"""Out-of-core CRS reprojection and multi-raster merge.

Public API
----------
reproject(raster, target_crs, ...)
    Reproject a DataArray to a new coordinate reference system.
merge(rasters, ...)
    Merge multiple DataArrays into a single mosaic.
"""
from __future__ import annotations

import math

import numpy as np
import xarray as xr

from xrspatial.utils import _dask_task_name_kwargs, _validate_raster

from ._crs_utils import _detect_band_nodata, _detect_nodata, _detect_source_crs, _resolve_crs
from ._grid import (_MAX_OUTPUT_PIXELS, _chunk_bounds, _compute_chunk_layout, _compute_output_grid,
                    _make_output_coords, _validate_grid_params)
from ._interpolate import _resample_cupy_native, _resample_numpy, _validate_resampling
from ._itrf import itrf_transform
from ._itrf import list_frames as itrf_frames
from ._merge import _merge_arrays_numpy, _validate_strategy
from ._transform import ApproximateTransform
from ._vertical import (depth_to_ellipsoidal, ellipsoidal_to_depth, ellipsoidal_to_orthometric,
                        geoid_height, geoid_height_raster, orthometric_to_ellipsoidal)

__all__ = [
    'reproject', 'merge',
    'geoid_height', 'geoid_height_raster',
    'ellipsoidal_to_orthometric', 'orthometric_to_ellipsoidal',
    'depth_to_ellipsoidal', 'ellipsoidal_to_depth',
    'itrf_transform', 'itrf_frames',
]


# ---------------------------------------------------------------------------
# Source geometry helpers
# ---------------------------------------------------------------------------

# Dimension names that _find_spatial_dims recognises as the vertical (y)
# and horizontal (x) spatial axes. Lookup is a plain set-membership test,
# so matching is exact and case-sensitive: only the spellings listed here
# are accepted.
_Y_NAMES = {'y', 'lat', 'latitude', 'Y', 'Lat', 'Latitude'}
_X_NAMES = {'x', 'lon', 'longitude', 'X', 'Lon', 'Longitude'}

# Output byte budget above which merge() auto-promotes an in-memory mosaic
# to the lazy dask path instead of allocating the whole array.
_MERGE_OOM_THRESHOLD = 512 * 1024 * 1024  # 512 MB

# Byte budget above which reproject() auto-promotes an in-memory raster
# (numpy or cupy) to the chunked dask path. Compared against the input
# array size and one float64 output array independently, not against
# total memory: the eager numpy path holds ~7 output-sized float64
# temporaries (coordinate grids, pixel-index grids, result), so
# eager-path peak RSS can reach ~7x this budget. That multiplier is why a
# small input upsampled to a large output exhausted memory long before
# the _MAX_OUTPUT_PIXELS guard tripped (#3267). The same budget gates the
# cupy promotion added in #3281.
_REPROJECT_OOM_THRESHOLD = 512 * 1024 * 1024  # 512 MB

# Map friendly vertical datum tokens to EPSG codes so attrs['vertical_crs']
# from reproject output matches the convention used by xrspatial.geotiff,
# which also writes EPSG ints to attrs['vertical_crs'].
_VERTICAL_DATUM_EPSG = {
    'EGM96': 5773,        # EGM96 height
    'EGM2008': 3855,      # EGM2008 height
    'ellipsoidal': 4979,  # WGS 84 (3D, ellipsoidal height)
}

# Sentinel marking the deprecated ``src_vertical_crs`` / ``tgt_vertical_crs``
# kwargs as "not passed". Distinct from None so we can tell an explicit
# ``src_vertical_crs=None`` apart from the default and only warn when the
# caller actually used the old name.
_DEPRECATED = object()


def _resolve_deprecated_vertical_kwarg(old_name, old_val, new_name, new_val):
    """Pick the effective value for a renamed vertical-CRS keyword.

    Parameters
    ----------
    old_name : str
        Deprecated keyword name, used only in messages.
    old_val : object
        Value received under the deprecated name, or the ``_DEPRECATED``
        sentinel when the caller did not pass it.
    new_name : str
        Replacement keyword name, used only in messages.
    new_val : object
        Value received under the replacement name (``None`` when unset).

    Returns
    -------
    object
        ``new_val`` when the deprecated keyword was not used, otherwise
        ``old_val``.

    Raises
    ------
    TypeError
        If the deprecated keyword was used and ``new_val`` is not ``None``.
        The ``DeprecationWarning`` has already been emitted at that point.
    """
    old_name_unused = old_val is _DEPRECATED
    if old_name_unused:
        return new_val

    import warnings

    deprecation_text = (
        f"reproject(): {old_name!r} is deprecated, use {new_name!r} instead."
    )
    # stacklevel=3 attributes the warning two frames above this helper.
    warnings.warn(deprecation_text, DeprecationWarning, stacklevel=3)

    if new_val is None:
        return old_val

    raise TypeError(
        f"reproject(): pass either {new_name!r} or the deprecated "
        f"{old_name!r}, not both."
    )


def _find_spatial_dims(raster):
    """Return the ``(ydim, xdim)`` dimension names of *raster*.

    Dimension names are matched against ``_Y_NAMES`` and ``_X_NAMES``.
    If a dimension is recognised for both axes, the last match on each
    axis is returned. Otherwise the last two entries of ``raster.dims``
    are returned as-is, whatever they are called.
    """
    dim_names = raster.dims
    y_matches = [name for name in dim_names if name in _Y_NAMES]
    x_matches = [name for name in dim_names if name in _X_NAMES]

    if y_matches and x_matches:
        return y_matches[-1], x_matches[-1]

    # Names not recognised on both axes: assume trailing dims are spatial.
    return dim_names[-2], dim_names[-1]


# Default tolerance for the regular-spacing check. Coordinates loaded from
# real GeoTIFFs can drift a few ULPs from perfectly uniform after pixel-to-
# world transforms, so 1e-6 (relative) is loose enough to accept those while
# still catching the single-pixel perturbation case in #2184.
_REGULAR_COORD_RTOL = 1e-6


def _validate_regular_axis(coords, axis_name, func_name, rtol=_REGULAR_COORD_RTOL):
    """Reject a coordinate axis that is not a finite, uniform 1-D ramp.

    The pixel-resolution math in `_source_bounds` and the chunk workers
    treats the grid as uniformly spaced, so irregular or non-monotonic
    coordinates would otherwise yield wrong georeferencing without any
    error (see #2184).

    Parameters
    ----------
    coords : array-like
        Coordinate values along one axis; must be 1-D.
    axis_name : str
        Axis label ("x" or "y") inserted into error messages.
    func_name : str
        Name of the public caller, used as the error-message prefix.
    rtol : float
        Largest allowed deviation of any step from the median step,
        expressed as a fraction of the absolute median step.

    Raises
    ------
    ValueError
        If the array is not 1-D, holds NaN/inf, is not strictly
        monotonic, or has a step deviating from the median step by more
        than ``rtol`` times its magnitude.
    """
    values = np.asarray(coords)

    if values.ndim != 1:
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' must be 1-D, "
            f"got shape {values.shape}."
        )

    # Fewer than two samples: there is no step to inspect. The caller
    # falls back to res=1.0 in _source_bounds for that case.
    if values.size < 2:
        return

    if not np.isfinite(values).all():
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' contains non-finite "
            f"values (NaN or inf)."
        )

    # Take the steps in float64 so integer coordinates still give float
    # values for the median / tolerance arithmetic below.
    steps = np.diff(np.asarray(values, dtype=np.float64))

    # A zero step (repeated coordinate) fails both comparisons, so it is
    # rejected here. A sign-only comparison would let it through; keep
    # the two explicit strict inequalities.
    all_increasing = np.all(steps > 0)
    all_decreasing = np.all(steps < 0)
    if not all_increasing and not all_decreasing:
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' must be strictly "
            f"monotonic (all ascending or all descending). The reproject "
            f"pipeline assumes a uniformly-spaced grid; see #2184."
        )

    typical_step = float(np.median(steps))
    offsets = np.abs(steps - typical_step)
    largest_offset = float(np.max(offsets))

    if largest_offset > rtol * abs(typical_step):
        # Point at the offending pair of samples so the caller does not
        # have to recompute the differences to find it.
        at = int(np.argmax(offsets))
        raise ValueError(
            f"{func_name}(): coordinate '{axis_name}' is not regularly "
            f"spaced. Median step is {typical_step!r}; worst deviation is "
            f"{largest_offset!r} at index {at} (between {axis_name}[{at}]"
            f"={float(values[at])!r} and {axis_name}[{at + 1}]"
            f"={float(values[at + 1])!r}). The reproject pipeline "
            f"assumes a uniformly-spaced grid; see #2184."
        )


def _validate_source_coords(raster, func_name):
    """Run the regular-axis check on the y axis, then the x axis, of *raster*.

    Any ``ValueError`` from `_validate_regular_axis` propagates to the
    caller, prefixed with *func_name*.
    """
    y_name, x_name = _find_spatial_dims(raster)
    for label, dim_name in (('y', y_name), ('x', x_name)):
        _validate_regular_axis(raster.coords[dim_name].values, label, func_name)


def _source_bounds(raster):
    """Return the pixel-edge extent ``(left, bottom, right, top)`` of *raster*.

    The coordinates are treated as pixel centres. The pixel size on each
    axis is the absolute difference of the first two coordinate values
    (1.0 when the axis has a single sample), and half a pixel is added
    on every side of the coordinate min/max.
    """
    y_name, x_name = _find_spatial_dims(raster)
    y_centres = raster.coords[y_name].values
    x_centres = raster.coords[x_name].values

    def _pixel_size(centres):
        # Spacing is read from the first two samples only.
        return abs(float(centres[1] - centres[0])) if len(centres) > 1 else 1.0

    half_y = _pixel_size(y_centres) / 2
    half_x = _pixel_size(x_centres) / 2

    return (
        float(np.min(x_centres)) - half_x,
        float(np.min(y_centres)) - half_y,
        float(np.max(x_centres)) + half_x,
        float(np.max(y_centres)) + half_y,
    )


def _is_y_descending(raster):
    """Tell whether the first y coordinate is larger than the last one.

    A raster with fewer than two rows is reported as descending.
    """
    y_name = _find_spatial_dims(raster)[0]
    y_vals = raster.coords[y_name].values
    return True if len(y_vals) < 2 else float(y_vals[0]) > float(y_vals[-1])


def _is_x_descending(raster):
    """Tell whether the first x coordinate is larger than the last one.

    Horizontal counterpart of :func:`_is_y_descending`. A raster with
    fewer than two columns is reported as ascending, in line with
    :func:`_make_output_coords`, which always emits ascending x.
    """
    x_name = _find_spatial_dims(raster)[1]
    x_vals = raster.coords[x_name].values
    return False if len(x_vals) < 2 else float(x_vals[0]) > float(x_vals[-1])


# ---------------------------------------------------------------------------
# Per-chunk coordinate transform
# ---------------------------------------------------------------------------

def _transform_coords(transformer, chunk_bounds, chunk_shape,
                      transform_precision, src_crs=None, tgt_crs=None):
    """Compute source CRS coordinates for every output pixel.

    When *transform_precision* is 0, every pixel is transformed through
    pyproj exactly (same strategy as GDAL/rasterio).  Otherwise an
    approximate bilinear control-grid interpolation is used.

    For common CRS pairs (WGS84/NAD83 <-> UTM, WGS84 <-> Web Mercator),
    a Numba JIT fast path bypasses pyproj entirely for ~30x speedup.

    Returns
    -------
    src_y, src_x : ndarray (height, width)
    """
    # Try Numba fast path for common projections.
    # transform_precision == 0 is the documented escape hatch for exact
    # per-pixel pyproj transforms, so skip the approximate fast path then.
    if (transform_precision != 0
            and src_crs is not None and tgt_crs is not None):
        try:
            from ._projections import try_numba_transform
            result = try_numba_transform(
                src_crs, tgt_crs, chunk_bounds, chunk_shape,
            )
            if result is not None:
                return result
        except (ImportError, ModuleNotFoundError):
            pass  # fall through to pyproj

    height, width = chunk_shape
    left, bottom, right, top = chunk_bounds
    res_x = (right - left) / width
    res_y = (top - bottom) / height

    if transform_precision == 0:
        # Exact per-pixel transform via pyproj bulk API.
        # Process in row strips to keep memory bounded and improve
        # cache locality for large rasters.
        out_x_1d = left + (np.arange(width, dtype=np.float64) + 0.5) * res_x
        src_x_out = np.empty((height, width), dtype=np.float64)
        src_y_out = np.empty((height, width), dtype=np.float64)
        strip = 256
        for r0 in range(0, height, strip):
            r1 = min(r0 + strip, height)
            n_rows = r1 - r0
            out_y_strip = top - (np.arange(r0, r1, dtype=np.float64) + 0.5) * res_y
            # Broadcast to (n_rows, width) without allocating a full copy
            sx, sy = transformer.transform(
                np.tile(out_x_1d, n_rows),
                np.repeat(out_y_strip, width),
            )
            src_x_out[r0:r1] = np.asarray(sx, dtype=np.float64).reshape(n_rows, width)
            src_y_out[r0:r1] = np.asarray(sy, dtype=np.float64).reshape(n_rows, width)
        return src_y_out, src_x_out

    # Approximate: bilinear interpolation on a coarse control grid.
    approx = ApproximateTransform(
        transformer, chunk_bounds, chunk_shape,
        precision=transform_precision,
    )
    row_grid = np.arange(height, dtype=np.float64)[:, np.newaxis]
    col_grid = np.arange(width, dtype=np.float64)[np.newaxis, :]
    row_grid = np.broadcast_to(row_grid, (height, width))
    col_grid = np.broadcast_to(col_grid, (height, width))
    return approx(row_grid, col_grid)


# ---------------------------------------------------------------------------
# Per-chunk worker functions
# ---------------------------------------------------------------------------

def _reproject_chunk_numpy(
    source_data, source_bounds_tuple, source_shape, source_y_desc,
    src_wkt, tgt_wkt,
    chunk_bounds_tuple, chunk_shape,
    resampling, nodata, transform_precision,
    source_x_desc=False,
    band_nodata=None,
):
    """Reproject a single output chunk (numpy backend).

    Called inside ``dask.delayed`` for the dask path, or directly for numpy.
    CRS objects are passed as WKT strings for pickle safety.

    ``source_x_desc`` mirrors ``source_y_desc`` for the horizontal axis:
    when True, source column 0 is at the maximum x and column ``src_w-1``
    is at the minimum x. Defaults to False so older callers keep working.

    Parameters
    ----------
    source_data : array-like
        Source pixels, sliced as ``[rows, cols]``. When ``ndim == 3`` the
        third axis is treated as the band axis. Objects exposing
        ``.compute()`` are materialised after the window is sliced.
    source_bounds_tuple : (left, bottom, right, top)
        Source extent in the source CRS.
    source_shape : (height, width)
        Source size in pixels; together with the bounds it fixes the
        source pixel size.
    source_y_desc : bool
        When True, source row 0 is at the top (maximum y).
    src_wkt, tgt_wkt : str
        Source and target CRS, decoded with ``_crs_from_wkt``.
    chunk_bounds_tuple : (left, bottom, right, top)
        Output chunk extent in the target CRS.
    chunk_shape : (height, width)
        Output chunk size in pixels.
    resampling : str
        Forwarded unchanged to ``_resample_numpy``.
    nodata : float
        Fill value for empty chunks; also the sentinel masked to NaN
        before resampling unless it is itself NaN.
    transform_precision : int
        0 skips the Numba fast path; also forwarded to
        ``_transform_coords``.
    source_x_desc : bool, optional
        Horizontal direction flag, see above.
    band_nodata : sequence or None, optional
        Per-band source sentinels, indexed by band number. Only read in
        the 3-D branch.

    Returns
    -------
    ndarray
        Array of shape ``chunk_shape`` (plus a trailing band axis for 3-D
        sources). Integer sources are rounded, clipped and cast back to
        their dtype; nodata-filled early returns are float64 for every
        other source dtype.
    """
    from ._crs_utils import _crs_from_wkt

    src_crs = _crs_from_wkt(src_wkt)
    tgt_crs = _crs_from_wkt(tgt_wkt)

    # Try Numba fast path first (avoids creating pyproj Transformer).
    # transform_precision == 0 forces the exact pyproj path, so skip Numba.
    numba_result = None
    if transform_precision != 0:
        try:
            from ._projections import try_numba_transform
            numba_result = try_numba_transform(
                src_crs, tgt_crs, chunk_bounds_tuple, chunk_shape,
            )
        except (ImportError, ModuleNotFoundError):
            pass

    if numba_result is not None:
        src_y, src_x = numba_result
    else:
        # Fallback: create pyproj Transformer (expensive)
        from ._crs_utils import _require_pyproj
        pyproj = _require_pyproj()
        # Transformer runs target -> source: each output pixel is mapped
        # back to where it falls in the source raster.
        transformer = pyproj.Transformer.from_crs(
            tgt_crs, src_crs, always_xy=True
        )
        # Pass src_crs/tgt_crs as None: the numba fast path was already
        # tried above and returned None, and _transform_coords gates its
        # own try_numba_transform retry on both CRSes being non-None.
        # Re-trying would repeat the CRS param parsing and chunk-sized
        # coordinate allocations for nothing (#3106).
        src_y, src_x = _transform_coords(
            transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
        )

    # Convert source CRS coordinates to source pixel coordinates
    src_left, src_bottom, src_right, src_top = source_bounds_tuple
    src_h, src_w = source_shape
    src_res_x = (src_right - src_left) / src_w
    src_res_y = (src_top - src_bottom) / src_h

    # The -0.5 shifts from pixel-edge to pixel-centre indexing: a
    # coordinate exactly at the centre of pixel 0 maps to index 0.0.
    if source_x_desc:
        src_col_px = (src_right - src_x) / src_res_x - 0.5
    else:
        src_col_px = (src_x - src_left) / src_res_x - 0.5
    if source_y_desc:
        src_row_px = (src_top - src_y) / src_res_y - 0.5
    else:
        src_row_px = (src_y - src_bottom) / src_res_y - 0.5

    # Determine source window needed
    # NOTE(review): nanmin/nanmax skip NaN but not +/-inf, so a single
    # infinite coordinate makes the finite checks below return an
    # all-nodata chunk. Confirm the transforms above report unprojectable
    # pixels as NaN rather than inf.
    r_min = np.nanmin(src_row_px)
    r_max = np.nanmax(src_row_px)
    c_min = np.nanmin(src_col_px)
    c_max = np.nanmax(src_col_px)

    # 3-D source: empty-chunk returns must carry the band axis or the
    # dask map_blocks template (which is 3-D for 3-D sources) sees a
    # shape mismatch (#2027).
    if source_data.ndim == 3:
        _empty_shape = (*chunk_shape, source_data.shape[2])
    else:
        _empty_shape = chunk_shape

    # Empty-chunk fills must match the dtype the data path returns and the
    # dask template advertises (#3096): integer sources round-trip back to
    # their dtype, floats stay float64. Without this, a single no-overlap
    # chunk promoted the whole assembled dask output to float64. The
    # resolved nodata is guaranteed representable for integer rasters
    # (#2185/#2572).
    if np.issubdtype(source_data.dtype, np.integer):
        _empty_dtype = source_data.dtype
    else:
        _empty_dtype = np.float64

    if not np.isfinite(r_min) or not np.isfinite(r_max):
        return np.full(_empty_shape, nodata, dtype=_empty_dtype)
    if not np.isfinite(c_min) or not np.isfinite(c_max):
        return np.full(_empty_shape, nodata, dtype=_empty_dtype)

    # Pad the integer window beyond the touched pixels (2 before, 3 after
    # as an exclusive end) -- presumably to cover the resampling kernel
    # footprint; TODO confirm against _resample_numpy.
    r_min = int(np.floor(r_min)) - 2
    r_max = int(np.ceil(r_max)) + 3
    c_min = int(np.floor(c_min)) - 2
    c_max = int(np.ceil(c_max)) + 3

    # Check overlap
    if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
        return np.full(_empty_shape, nodata, dtype=_empty_dtype)

    # Clip to source bounds
    r_min_clip = max(0, r_min)
    r_max_clip = min(src_h, r_max)
    c_min_clip = max(0, c_min)
    c_max_clip = min(src_w, c_max)

    # Guard: cap source window to prevent OOM if projection maps a small
    # output chunk to a huge source area (e.g. polar stereographic edges).
    # Note the chunk is returned as all-nodata in that case, without
    # raising or warning.
    _MAX_WINDOW_PIXELS = 64 * 1024 * 1024  # 64 Mpix (~512 MB for float64)
    win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip)
    if win_pixels > _MAX_WINDOW_PIXELS:
        return np.full(_empty_shape, nodata, dtype=_empty_dtype)

    # Extract source window
    window = source_data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
    # Lazy (dask-like) sources are materialised only for this window.
    if hasattr(window, 'compute'):
        window = window.compute()
    window = np.asarray(window)
    # Remember the source dtype so integer rasters can be cast back.
    orig_dtype = window.dtype

    # Adjust coordinates relative to window
    local_row = src_row_px - r_min_clip
    local_col = src_col_px - c_min_clip

    # Multi-band: reproject each band separately, share coordinates
    if window.ndim == 3:
        n_bands = window.shape[2]
        bands = []
        for b in range(n_bands):
            band_data = window[:, :, b].astype(np.float64)
            # Mask this band with its own source sentinel when the raster
            # declares per-band nodata; otherwise fall back to the single
            # resolved sentinel (#2647).
            src_nd = band_nodata[b] if band_nodata is not None else nodata
            if not np.isnan(src_nd):
                band_data[band_data == src_nd] = np.nan
            band_result = _resample_numpy(band_data, local_row, local_col,
                                          resampling=resampling, nodata=nodata)
            if np.issubdtype(orig_dtype, np.integer):
                info = np.iinfo(orig_dtype)
                band_result = np.clip(np.round(band_result), info.min, info.max).astype(orig_dtype)
            bands.append(band_result)
        # Bands go back on the trailing axis, matching _empty_shape.
        return np.stack(bands, axis=-1)

    # Single-band path
    window = window.astype(np.float64)

    # Convert sentinel nodata to NaN so numba kernels can detect it
    if not np.isnan(nodata):
        window[window == nodata] = np.nan

    result = _resample_numpy(window, local_row, local_col,
                             resampling=resampling, nodata=nodata)

    # Clamp and cast back for integer source dtypes
    if np.issubdtype(orig_dtype, np.integer):
        info = np.iinfo(orig_dtype)
        result = np.clip(np.round(result), info.min, info.max).astype(orig_dtype)

    return result


def _reproject_chunk_cupy(
    source_data, source_bounds_tuple, source_shape, source_y_desc,
    src_wkt, tgt_wkt,
    chunk_bounds_tuple, chunk_shape,
    resampling, nodata, transform_precision,
    source_x_desc=False,
    band_nodata=None,
):
    """CuPy variant of ``_reproject_chunk_numpy``.

    ``source_x_desc`` carries the horizontal direction flag (same meaning
    as in :func:`_reproject_chunk_numpy`).

    The signature matches :func:`_reproject_chunk_numpy` argument for
    argument. The coordinate transform first tries ``try_cuda_transform``
    (unless ``transform_precision == 0`` or a CRS is None) and otherwise
    falls back to ``_transform_coords`` on the host; resampling always
    goes through ``_resample_cupy_native``.

    Returns
    -------
    cupy.ndarray
        Array of shape ``chunk_shape`` (plus a trailing band axis for 3-D
        sources). Integer sources are rounded, clipped and cast back to
        their dtype; nodata-filled early returns are float64 for every
        other source dtype.
    """
    import cupy as cp

    from ._crs_utils import _crs_from_wkt

    src_crs = _crs_from_wkt(src_wkt)
    tgt_crs = _crs_from_wkt(tgt_wkt)

    # 3-D source: empty-chunk returns must carry the band axis to match
    # the dask+cupy map_blocks template (#2027).
    if source_data.ndim == 3:
        _empty_shape = (*chunk_shape, source_data.shape[2])
    else:
        _empty_shape = chunk_shape

    # Empty-chunk fills must match the dtype the data path returns (#3096);
    # see the matching block in _reproject_chunk_numpy.
    if np.issubdtype(source_data.dtype, np.integer):
        _empty_dtype = source_data.dtype
    else:
        _empty_dtype = np.float64

    # Try CUDA transform first (keeps coordinates on-device).
    # transform_precision == 0 forces the exact pyproj path, so skip CUDA.
    cuda_result = None
    if (transform_precision != 0
            and src_crs is not None and tgt_crs is not None):
        try:
            from ._projections_cuda import try_cuda_transform
            cuda_result = try_cuda_transform(
                src_crs, tgt_crs, chunk_bounds_tuple, chunk_shape,
            )
        except (ImportError, ModuleNotFoundError):
            pass

    # Both branches below must leave src_row_px, src_col_px, src_h, src_w
    # and the integer r_min/r_max/c_min/c_max defined for the shared
    # windowing code that follows.
    if cuda_result is not None:
        src_y, src_x = cuda_result  # cupy arrays
        src_left, src_bottom, src_right, src_top = source_bounds_tuple
        src_h, src_w = source_shape
        src_res_x = (src_right - src_left) / src_w
        src_res_y = (src_top - src_bottom) / src_h
        # Pixel coordinate math stays on GPU via cupy operators
        if source_x_desc:
            src_col_px = (src_right - src_x) / src_res_x - 0.5
        else:
            src_col_px = (src_x - src_left) / src_res_x - 0.5
        if source_y_desc:
            src_row_px = (src_top - src_y) / src_res_y - 0.5
        else:
            src_row_px = (src_y - src_bottom) / src_res_y - 0.5
        # Need min/max on CPU for window selection.
        # Stack the four reductions and pull them across in one device-to-host
        # transfer to avoid four separate synchronous syncs.
        mins_maxes = cp.stack([
            cp.nanmin(src_row_px), cp.nanmax(src_row_px),
            cp.nanmin(src_col_px), cp.nanmax(src_col_px),
        ])
        r_min_val, r_max_val, c_min_val, c_max_val = (
            float(v) for v in mins_maxes.get()
        )
        # NOTE(review): as in the numpy worker, a non-finite extreme
        # (including +/-inf from a single pixel) turns the whole chunk
        # into nodata -- confirm that is intended.
        if not (np.isfinite(r_min_val) and np.isfinite(r_max_val)
                and np.isfinite(c_min_val) and np.isfinite(c_max_val)):
            return cp.full(_empty_shape, nodata, dtype=_empty_dtype)
        # Same -2 / +3 window padding as the numpy worker.
        r_min = int(np.floor(r_min_val)) - 2
        r_max = int(np.ceil(r_max_val)) + 3
        c_min = int(np.floor(c_min_val)) - 2
        c_max = int(np.ceil(c_max_val)) + 3
        # Coordinates stay as CuPy arrays for native CUDA resampling
    else:
        # CPU fallback (Numba JIT or pyproj)
        from ._crs_utils import _require_pyproj
        pyproj = _require_pyproj()
        transformer = pyproj.Transformer.from_crs(
            tgt_crs, src_crs, always_xy=True
        )
        # Unlike the numpy worker, the CRS objects are forwarded here, so
        # _transform_coords may still attempt its own Numba fast path.
        src_y, src_x = _transform_coords(
            transformer, chunk_bounds_tuple, chunk_shape, transform_precision,
            src_crs=src_crs, tgt_crs=tgt_crs,
        )

        src_left, src_bottom, src_right, src_top = source_bounds_tuple
        src_h, src_w = source_shape
        src_res_x = (src_right - src_left) / src_w
        src_res_y = (src_top - src_bottom) / src_h

        if source_x_desc:
            src_col_px = (src_right - src_x) / src_res_x - 0.5
        else:
            src_col_px = (src_x - src_left) / src_res_x - 0.5
        if source_y_desc:
            src_row_px = (src_top - src_y) / src_res_y - 0.5
        else:
            src_row_px = (src_y - src_bottom) / src_res_y - 0.5

        r_min = np.nanmin(src_row_px)
        r_max = np.nanmax(src_row_px)
        c_min = np.nanmin(src_col_px)
        c_max = np.nanmax(src_col_px)
        if not np.isfinite(r_min) or not np.isfinite(r_max):
            return cp.full(_empty_shape, nodata, dtype=_empty_dtype)
        if not np.isfinite(c_min) or not np.isfinite(c_max):
            return cp.full(_empty_shape, nodata, dtype=_empty_dtype)
        r_min = int(np.floor(r_min)) - 2
        r_max = int(np.ceil(r_max)) + 3
        c_min = int(np.floor(c_min)) - 2
        c_max = int(np.ceil(c_max)) + 3

    # No overlap between the padded window and the source raster.
    if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
        return cp.full(_empty_shape, nodata, dtype=_empty_dtype)

    # Clip the window to the source raster.
    r_min_clip = max(0, r_min)
    r_max_clip = min(src_h, r_max)
    c_min_clip = max(0, c_min)
    c_max_clip = min(src_w, c_max)

    # Guard: cap source window to prevent GPU OOM if projection maps a
    # small output chunk to a huge source area (matches numpy path).
    _MAX_WINDOW_PIXELS = 64 * 1024 * 1024  # 64 Mpix (~512 MB for float64)
    win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip)
    if win_pixels > _MAX_WINDOW_PIXELS:
        return cp.full(_empty_shape, nodata, dtype=_empty_dtype)

    window = source_data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
    # Lazy (dask-like) sources are materialised only for this window.
    if hasattr(window, 'compute'):
        window = window.compute()
    # Host arrays are uploaded to the device; cupy arrays pass through.
    if not isinstance(window, cp.ndarray):
        window = cp.asarray(window)
    orig_dtype = window.dtype

    # Adjust coordinates relative to window (stays on GPU if CuPy)
    local_row = src_row_px - r_min_clip
    local_col = src_col_px - c_min_clip

    # Multi-band: reproject each band separately, share coordinates.
    # Matches the 3-D branch in _reproject_chunk_numpy so 3-D inputs work
    # on cupy and dask+cupy backends instead of crashing with a CUDA
    # signature mismatch (#2027).
    if window.ndim == 3:
        n_bands = window.shape[2]
        # The coordinate arrays are shared by every band. On the CPU
        # transform fallback they arrive as numpy; convert them to the
        # device once here, otherwise _resample_cupy_native re-uploads
        # the same two chunk-sized arrays on every band iteration (#3268).
        if not isinstance(local_row, cp.ndarray):
            local_row = cp.asarray(local_row)
        if not isinstance(local_col, cp.ndarray):
            local_col = cp.asarray(local_col)
        bands = []
        for b in range(n_bands):
            band_data = window[:, :, b].astype(cp.float64)
            # Mask this band with its own source sentinel when the raster
            # declares per-band nodata; otherwise fall back to the single
            # resolved sentinel (#2647). Pre-converting to NaN here lets
            # each band use a different source sentinel; the native kernel
            # still fills out-of-bounds pixels with the resolved `nodata`.
            src_nd = band_nodata[b] if band_nodata is not None else nodata
            if not np.isnan(src_nd):
                band_data = cp.where(
                    band_data == src_nd, cp.nan, band_data,
                )
            # Always resample through the native CUDA kernels so the cupy
            # backend matches numpy exactly. They accept CPU coordinate
            # arrays (transferring them to the GPU) and do the
            # nodata->NaN conversion internally, so they serve both the
            # on-device coordinate path and the pyproj fallback. Using
            # cupyx.scipy.ndimage.map_coordinates here instead would
            # diverge from numpy: it bleeds the cval=0.0 constant into the
            # half-pixel boundary band rather than renormalizing, and its
            # order=3 path is a B-spline rather than Catmull-Rom (#2620).
            band_result = _resample_cupy_native(
                band_data, local_row, local_col,
                resampling=resampling, nodata=nodata,
            )
            if np.issubdtype(orig_dtype, np.integer):
                info = np.iinfo(orig_dtype)
                band_result = cp.clip(
                    cp.round(band_result), info.min, info.max,
                ).astype(orig_dtype)
            bands.append(band_result)
        # Bands go back on the trailing axis, matching _empty_shape.
        return cp.stack(bands, axis=-1)

    window = window.astype(cp.float64)

    # Always resample through the native CUDA kernels for numpy parity.
    # local_row/local_col may be CuPy (on-device transform) or numpy
    # (pyproj fallback); _resample_cupy_native handles both and does the
    # nodata->NaN conversion internally. The previous
    # cupyx.scipy.ndimage.map_coordinates fallback diverged from numpy at
    # chunk edges and for cubic resampling (#2620).
    result = _resample_cupy_native(window, local_row, local_col,
                                   resampling=resampling, nodata=nodata)

    # Clamp and cast back for integer source dtypes (parity with numpy path)
    if np.issubdtype(orig_dtype, np.integer):
        info = np.iinfo(orig_dtype)
        result = cp.clip(cp.round(result), info.min, info.max).astype(orig_dtype)
    return result


# ---------------------------------------------------------------------------
# reproject()
# ---------------------------------------------------------------------------

def reproject(
    raster,
    target_crs,
    *,
    source_crs=None,
    resolution=None,
    bounds=None,
    width=None,
    height=None,
    resampling='bilinear',
    nodata=None,
    transform_precision=16,
    chunk_size=None,
    name=None,
    max_memory=None,
    source_vertical_crs=None,
    target_vertical_crs=None,
    bounds_policy="auto",
    src_vertical_crs=_DEPRECATED,
    tgt_vertical_crs=_DEPRECATED,
):
    """Reproject a raster DataArray to a new coordinate reference system.

    Works with numpy, cupy, dask+numpy, and dask+cupy backends. Dask
    inputs are handled fully lazily: each output chunk independently
    reads only the source pixels it needs. Numpy inputs whose input or
    output working set exceeds ~512 MB are routed through the same lazy
    dask path when dask is installed, so the result is dask-backed in
    that case. Without dask, a streaming fallback bounds memory via
    ``max_memory``.

    Parameters
    ----------
    raster : xr.DataArray
        Input raster with y/x coordinates.
    target_crs
        Target CRS in any format accepted by ``pyproj.CRS()``.
    source_crs : optional
        Source CRS. Auto-detected from *raster* if None.
    resolution : float or (float, float) or None
        Output pixel size in target CRS units.
    bounds : (left, bottom, right, top) or None
        Explicit output extent in target CRS.
    width, height : int or None
        Explicit output grid dimensions.
    resampling : str
        One of 'nearest', 'bilinear', 'cubic'.
    nodata : float or None
        Nodata value. Auto-detected if None. For integer input dtypes,
        an explicit value that does not fit the dtype range raises
        ``ValueError`` (e.g. ``nodata=-9999`` with a ``uint8`` raster).
        Attrs/rioxarray-derived out-of-range values emit a
        ``UserWarning`` and fall back to ``dtype.min`` for signed or
        ``dtype.max`` for unsigned so legacy files still load (#2572).
    transform_precision : int
        Control-grid subdivisions for the coordinate transform
        (default 16). Higher values increase accuracy at the cost of
        more pyproj calls. Set to 0 for exact per-pixel transforms
        matching GDAL/rasterio.
    chunk_size : int or (int, int) or None
        Output chunk size for dask. If None, defaults to 512 for the
        standard dask path and 2048 for the in-memory streaming and
        dask+cupy paths (chosen to amortize kernel launch overhead).
    name : str or None
        Name for the output DataArray.
    max_memory : int or str or None
        Maximum memory budget for the reprojection working set.
        Accepts bytes (int) or human-readable strings like ``'4GB'``,
        ``'512MB'``. Controls how many output tiles are processed in
        parallel for large-dataset streaming mode. Default None uses
        1GB. Has no effect for small datasets that fit in memory.
    source_vertical_crs : str or None
        Source vertical datum for height values. One of:

        - ``'EGM96'`` -- orthometric heights relative to EGM96 geoid (MSL)
        - ``'EGM2008'`` -- orthometric heights relative to EGM2008 geoid
        - ``'ellipsoidal'`` -- heights relative to the WGS84 ellipsoid
        - ``None`` -- no vertical transformation (default)

    target_vertical_crs : str or None
        Target vertical datum. Same options as *source_vertical_crs*.
        Both must be set to trigger a vertical transformation.
    bounds_policy : {"auto", "raw", "clamp", "percentile"}, default "auto"
        How to derive the output extent from the source extent when
        ``bounds`` is not supplied. Only relevant when projecting near
        a singularity (antimeridian, pole, projection edge):

        - ``"raw"``: use the true projected extent of the source
          corners and edges. No clamp, no percentile, no heuristic.
          The output may be very large if the input straddles a
          projection singularity. Use this when you want a true
          projection of the source extent.
        - ``"clamp"``: trim geographic source bounds inward by 0.01 deg
          from +/-180 longitude and +/-90 latitude before projecting.
          Avoids infinities at singularities. No percentile fallback.
          No-op on projected source CRSes (UTM, Mercator, etc.) since
          the clamp only applies in degrees.
        - ``"percentile"``: project a dense interior grid of the source
          extent and use the 2nd/98th percentiles of the result as the
          output bounds. Rejects projection outliers at the cost of
          trimming valid pixels.
        - ``"auto"`` (default): apply ``"clamp"`` for geographic source
          CRSes and fall back to ``"percentile"`` when the projected
          extent is more than 50x the source extent. Matches the
          historical behaviour.

        When ``"auto"``, ``"clamp"``, or ``"percentile"`` actually
        alters the bounds, a ``UserWarning`` is emitted naming the
        policy and reporting the per-side delta versus the raw
        projected bounds. Filter with ``warnings.filterwarnings`` if
        the crop is intentional.
    src_vertical_crs : str or None
        Deprecated alias for *source_vertical_crs*. Passing it emits a
        ``DeprecationWarning``.
    tgt_vertical_crs : str or None
        Deprecated alias for *target_vertical_crs*. Passing it emits a
        ``DeprecationWarning``.

    Returns
    -------
    xr.DataArray
        The output ``attrs['crs']`` is in WKT format.

        Whenever *target_vertical_crs* is set, ``attrs['vertical_crs']``
        records the target vertical datum's EPSG code (5773 for EGM96,
        3855 for EGM2008, 4979 for ellipsoidal WGS84) to match the
        convention used by ``xrspatial.geotiff``. The friendly string
        token (``'EGM96'`` etc.) is preserved under
        ``attrs['vertical_datum']``. Both attrs are written even when
        no shift is applied (e.g. when *source_vertical_crs* equals
        *target_vertical_crs*, or when only the target is given), so
        the output's vertical reference is always explicit.

        The output y coordinate is always emitted in descending order
        (top-down, north-up) and the output x coordinate is always
        emitted in ascending order (left-to-right) regardless of the
        input directions. This matches the standard raster convention
        and the output of common GIS libraries. Inputs with descending
        x are detected from the x coordinate values and handled the
        same way as descending y: the pixel-index mapping is mirrored
        so the output values stay correct.

        Non-spatial coords from the input (such as a scalar ``time``
        coord or a non-dimension coord that is not aligned to the
        spatial dims) are carried through to the output. Coords that
        are aligned to the input y or x dims are dropped because their
        values do not apply to the rebuilt grid.

    Raises
    ------
    ValueError
        If a vertical datum token is not recognised, if the source CRS
        cannot be detected, or if a materialized (non-dask) output
        grid exceeds the output pixel limit.

    Examples
    --------
    >>> import xarray as xr
    >>> import numpy as np
    >>> from xrspatial.reproject import reproject
    >>> raster = xr.DataArray(
    ...     np.random.rand(64, 64),
    ...     dims=['y', 'x'],
    ...     coords={'y': np.linspace(50, 40, 64),
    ...             'x': np.linspace(-5, 5, 64)},
    ...     attrs={'crs': 'EPSG:4326'},
    ... )
    >>> result = reproject(raster, 'EPSG:3857')
    >>> result.attrs['crs'].startswith(('PROJCRS', 'PROJCS'))
    True
    """
    # The abbreviated src_/tgt_ kwargs were renamed to match the
    # source_crs / target_crs spelling; fold any legacy usage into the
    # canonical names before doing anything else.
    source_vertical_crs = _resolve_deprecated_vertical_kwarg(
        'src_vertical_crs', src_vertical_crs,
        'source_vertical_crs', source_vertical_crs)
    target_vertical_crs = _resolve_deprecated_vertical_kwarg(
        'tgt_vertical_crs', tgt_vertical_crs,
        'target_vertical_crs', target_vertical_crs)

    _validate_raster(raster, func_name='reproject', name='raster',
                     ndim=(2, 3))

    # Irregular or non-monotonic coords must be rejected up front:
    # _source_bounds() derives the pixel size from the first two coord
    # samples only and the pixel math assumes uniform spacing, so such
    # input would otherwise yield silently wrong georeferencing (#2184).
    _validate_source_coords(raster, 'reproject')

    _validate_grid_params(
        resolution=resolution,
        bounds=bounds,
        width=width,
        height=height,
        transform_precision=transform_precision,
        func_name='reproject',
    )
    _validate_resampling(resampling)
    from ._grid import _validate_bounds_policy
    _validate_bounds_policy(bounds_policy, func_name='reproject')

    # Catch typos / unsupported vertical datums at the API boundary so
    # attrs['vertical_crs'] can never end up holding None.
    for kwarg_name, token in (('source_vertical_crs', source_vertical_crs),
                              ('target_vertical_crs', target_vertical_crs)):
        if token is None or token in _VERTICAL_DATUM_EPSG:
            continue
        raise ValueError(
            f"Unknown {kwarg_name}={token!r}; expected one of "
            f"{sorted(_VERTICAL_DATUM_EPSG)} or None."
        )

    # Bring 3-D inputs into canonical (y, x, band) order. The per-chunk
    # workers slice the source as ``source_data[r:, c:]`` and expect the
    # band axis last; a rasterio-style (band, y, x) array would have its
    # band/y axes sliced instead (#2182). The original order is kept so
    # the result can be transposed back before returning.
    original_dims = tuple(raster.dims)
    if raster.ndim == 3:
        y_in, x_in = _find_spatial_dims(raster)
        band_in = next(
            (d for d in original_dims if d not in (y_in, x_in)), None)
        if band_in is not None:
            canonical_dims = (y_in, x_in, band_in)
            if original_dims != canonical_dims:
                raster = raster.transpose(*canonical_dims)

    # Source / target CRS resolution.
    src_crs = _resolve_crs(source_crs)
    if src_crs is None:
        src_crs = _detect_source_crs(raster)
        if src_crs is None:
            raise ValueError(
                "Could not detect source CRS. Pass source_crs explicitly."
            )
    tgt_crs = _resolve_crs(target_crs)

    # Resolve the output nodata sentinel. The dtype hint makes integer
    # rasters receive an integer-compatible sentinel (dtype min for
    # signed, dtype max for unsigned) rather than NaN, which the
    # worker's cast-back would collapse to 0 and leave attrs['nodata']
    # contradicting the data (#2185).
    nodata_out = _detect_nodata(raster, nodata, dtype=raster.dtype)

    # Per-band source sentinels (rasterio ``nodatavals``), so each band
    # is masked with its own value before resampling (#2647). ``None``
    # means a single scalar covers every band and workers use
    # ``nodata_out``. The band axis is trailing at this point.
    band_count = raster.shape[2] if raster.ndim == 3 else None
    band_sentinels = _detect_band_nodata(raster, nodata, band_count)

    # Source geometry.
    src_bounds = _source_bounds(raster)
    src_y, src_x = _find_spatial_dims(raster)
    src_shape = (raster.sizes[src_y], raster.sizes[src_x])
    y_desc = _is_y_descending(raster)
    x_desc = _is_x_descending(raster)

    # Backend detection comes before the output grid so the size guard
    # can distinguish a lazy dask output (never materialized in full)
    # from a materializing backend.
    from ..utils import has_dask_array, is_cupy_array
    arr = raster.data

    dask_backed = False
    if has_dask_array():
        import dask.array as dask_array
        dask_backed = isinstance(arr, dask_array.Array)

    gpu_backed = False
    if not dask_backed:
        gpu_backed = is_cupy_array(arr)
    else:
        # Inspect the type of the blocks underneath the dask graph.
        try:
            from ..utils import is_cupy_backed
            gpu_backed = is_cupy_backed(raster)
        except ImportError:
            pass

    # The grid is computed with the size guard disabled: whether the
    # guard applies depends on the execution path, and choosing the
    # path needs the output shape. The guard is re-applied further down
    # once the path is known -- same pattern merge() uses.
    grid = _compute_output_grid(
        src_bounds, src_shape, src_crs, tgt_crs,
        resolution=resolution,
        bounds=bounds,
        width=width,
        height=height,
        bounds_policy=bounds_policy,
        lazy_output=True,
    )
    out_bounds = grid['bounds']
    out_shape = grid['shape']

    # Large workloads get wrapped in dask for chunked processing. Both
    # input and output size count: the eager numpy path holds several
    # output-sized float64 temporaries (coordinate grids, pixel-index
    # grids, result), so a small input upsampled to a large output can
    # run out of memory well before the _MAX_OUTPUT_PIXELS guard trips
    # (#3267). map_blocks produces an O(1) HighLevelGraph, so the
    # streaming fallback is only needed when dask is unavailable.
    stream = False
    if not dask_backed:
        in_nbytes = src_shape[0] * src_shape[1] * arr.dtype.itemsize
        out_nbytes = out_shape[0] * out_shape[1] * 8  # float64 working set
        if arr.ndim == 3:
            in_nbytes *= arr.shape[2]
            out_nbytes *= arr.shape[2]

        over_budget = (in_nbytes > _REPROJECT_OOM_THRESHOLD
                       or out_nbytes > _REPROJECT_OOM_THRESHOLD)
        if over_budget:
            chunks = chunk_size or 2048
            if isinstance(chunks, int):
                chunks = (chunks, chunks)
            if arr.ndim == 3:
                # The band axis stays a single chunk; workers read all
                # bands of a (y, x) window together.
                chunks = (chunks[0], chunks[1], arr.shape[2])
            try:
                import dask.array as dask_array
                # Wrapping an in-memory cupy array keeps the blocks
                # cupy-backed, so gpu_backed stays True and the
                # promoted raster goes through _reproject_dask_cupy
                # (which has its own VRAM check and chunked fallback).
                # Above the threshold an in-memory cupy input thus
                # yields a dask-of-cupy output, matching merge()'s
                # promotion contract (#3281).
                arr = dask_array.from_array(arr, chunks=chunks)
                raster = xr.DataArray(
                    arr,
                    dims=raster.dims,
                    coords=raster.coords,
                    name=raster.name,
                    attrs=raster.attrs,
                )
                dask_backed = True
            except ImportError:
                # No dask. The streaming fallback is a numpy CPU path,
                # so it is only used for numpy inputs; a cupy input
                # keeps the eager GPU path rather than silently moving
                # its data off the device (#3281).
                if not gpu_backed:
                    stream = True

    # Re-apply the output-size guard for paths that materialize the
    # whole output (in-memory numpy/cupy and streaming). A lazy dask
    # output is exempt because its peak memory is bounded by the chunk
    # size.
    if not dask_backed:
        n_out_pixels = out_shape[0] * out_shape[1]
        if n_out_pixels > _MAX_OUTPUT_PIXELS:
            raise ValueError(
                f"Computed output grid is too large ({out_shape[1]} x "
                f"{out_shape[0]} = {n_out_pixels:,} pixels, "
                f"limit is {_MAX_OUTPUT_PIXELS:,}). Increase the resolution "
                f"parameter or reduce the output extent."
            )

    # Output coordinates.
    y_coords, x_coords = _make_output_coords(out_bounds, out_shape)

    # CRS objects travel as WKT strings so they pickle safely.
    src_wkt = src_crs.to_wkt()
    tgt_wkt = tgt_crs.to_wkt()

    # Every backend worker takes the same leading positional arguments
    # and the same trailing keywords; only the chunking extras differ.
    shared_args = (
        raster, src_bounds, src_shape, y_desc,
        src_wkt, tgt_wkt,
        out_bounds, out_shape,
        resampling, nodata_out, transform_precision,
    )
    shared_kwargs = {'x_desc': x_desc, 'band_nodata': band_sentinels}

    if stream:
        result_data = _reproject_streaming(
            *shared_args,
            chunk_size or 2048,
            _parse_max_memory(max_memory),
            **shared_kwargs,
        )
    elif dask_backed and gpu_backed:
        result_data = _reproject_dask_cupy(
            *shared_args, chunk_size, **shared_kwargs)
    elif dask_backed:
        result_data = _reproject_dask(
            *shared_args, chunk_size, False, **shared_kwargs)
    elif gpu_backed:
        result_data = _reproject_inmemory_cupy(
            *shared_args, **shared_kwargs)
    else:
        result_data = _reproject_inmemory_numpy(
            *shared_args, **shared_kwargs)

    # Vertical datum shift: only when both datums are given and differ.
    if (source_vertical_crs is not None
            and target_vertical_crs is not None
            and source_vertical_crs != target_vertical_crs):
        result_data, nodata_out = _apply_vertical_shift(
            result_data, y_coords, x_coords,
            source_vertical_crs, target_vertical_crs,
            nodata_out,
            tgt_crs_wkt=tgt_wkt,
        )

    ydim, xdim = _find_spatial_dims(raster)

    # Input attrs are carried forward so units, long_name, scale_factor
    # etc. survive. `transform` and `res` are re-emitted for the new
    # grid (rasterio 6-tuple convention) and `crs_wkt` is dropped since
    # it would duplicate or contradict the canonical `crs`.
    out_left, _unused_bottom, _unused_right, out_top = grid['bounds']
    res_x = grid['res_x']
    res_y = grid['res_y']

    attrs = dict(raster.attrs)
    attrs.pop('crs_wkt', None)
    attrs.update(
        crs=tgt_wkt,
        nodata=nodata_out,
        res=(res_x, res_y),
        transform=(res_x, 0.0, out_left, 0.0, -res_y, out_top),
    )

    # Keep `_FillValue` in sync with the resolved nodata, but only when
    # the input carried that key in the first place.
    if '_FillValue' in raster.attrs:
        attrs['_FillValue'] = nodata_out

    # `nodatavals` (rasterio convention) is a per-band tuple; refresh it
    # so it cannot contradict attrs['nodata'] after the resample.
    if 'nodatavals' in raster.attrs:
        previous = raster.attrs['nodatavals']
        try:
            n_entries = max(1, len(previous))
        except TypeError:
            n_entries = 1
        attrs['nodatavals'] = (nodata_out,) * n_entries

    if target_vertical_crs is not None:
        # Same convention as xrspatial.geotiff: attrs['vertical_crs'] is
        # the EPSG integer, while the readable token is kept under
        # attrs['vertical_datum']. See GH issue #1570.
        attrs['vertical_crs'] = _VERTICAL_DATUM_EPSG.get(target_vertical_crs)
        attrs['vertical_datum'] = target_vertical_crs

    # Assemble output dims and coords. A 3-D result (multi-band source)
    # gets a trailing band dim named after the source's band dim.
    out_dims = [ydim, xdim]
    out_coords = {ydim: y_coords, xdim: x_coords}
    handled = (ydim, xdim)
    if result_data.ndim == 3:
        band_dim = next(
            (d for d in raster.dims if d not in (ydim, xdim)), 'band')
        out_dims.append(band_dim)
        if band_dim in raster.coords:
            out_coords[band_dim] = raster.coords[band_dim]
        handled = (ydim, xdim, band_dim)

    # Carry forward non-spatial coords (e.g. a scalar 'time' coord).
    # Coords aligned to the rebuilt spatial dims are skipped because
    # their values do not apply to the new grid.
    for coord_name, coord in raster.coords.items():
        if coord_name in handled:
            continue
        if ydim in coord.dims or xdim in coord.dims:
            continue
        out_coords[coord_name] = coord

    result = xr.DataArray(
        result_data,
        dims=out_dims,
        coords=out_coords,
        name=name or raster.name,
        attrs=attrs,
    )

    # Restore the caller's dim order, so a (band, y, x) source yields a
    # (band, y, x) output (#2182); internally 3-D arrays are always
    # built as (y, x, band).
    if (result.ndim == 3
            and set(original_dims) == set(result.dims)
            and tuple(result.dims) != original_dims):
        result = result.transpose(*original_dims)

    return result
def _promoted_vertical_dtype(src_dtype):
    """Return the float dtype to use when applying a geoid shift.

    The geoid offset is fractional metres, so any integer input has to be
    promoted before we can add it.

    Parameters
    ----------
    src_dtype : dtype-like
        Dtype of the height array about to be shifted.

    Returns
    -------
    type or None
        ``None`` when the input is already a float dtype (including
        ``float64``) -- its precision is left alone.  ``np.float32`` for
        any non-float input; float32 spacing is roughly 1 mm at altitudes
        around 10 km, which is well below the accuracy of the geoid grids
        themselves.
    """
    src_dtype = np.dtype(src_dtype)
    if np.issubdtype(src_dtype, np.floating):
        return None
    return np.float32


def _apply_vertical_shift(data, y_coords, x_coords, src_vcrs, tgt_vcrs, nodata,
                          tgt_crs_wkt=None):
    """Apply vertical datum shift to reprojected height values.

    The geoid undulation grid is in geographic (lon/lat) coordinates.
    If the output CRS is projected, coordinates are inverse-projected
    to geographic before the geoid lookup.

    Supported vertical CRS:

    - 'EGM96', 'EGM2008': orthometric heights (above geoid/MSL)
    - 'ellipsoidal': heights above WGS84 ellipsoid

    Any other source/target combination is a no-op: ``data`` and
    ``nodata`` are returned untouched.

    Backend handling
    ----------------
    The geoid undulation lookup is CPU-only (Numba JIT). To stay correct
    across backends:

    - For ``cupy`` arrays, the input is brought to host, shifted, and
      moved back to GPU.
    - For ``dask`` arrays, the shift is wrapped in ``map_blocks`` so each
      chunk's slab is materialised, shifted, and returned in place.
    - For 3-D ``(y, x, band)`` results, the shift is applied per band
      because the geoid undulation depends only on horizontal position.

    Returns
    -------
    (shifted, out_nodata) : tuple
        ``shifted`` is the height-shifted array, possibly promoted to a
        float dtype if the input was integer (see
        ``_promoted_vertical_dtype``).  ``out_nodata`` is the nodata
        sentinel that matches the returned dtype -- ``NaN`` whenever a
        promotion happened so the caller can rely on NaN semantics in
        the float output.
    """
    # Direction (sign convention) is the same regardless of backend.
    geoid_models = []
    signs = []

    if src_vcrs in ('EGM96', 'EGM2008') and tgt_vcrs == 'ellipsoidal':
        geoid_models.append(src_vcrs)
        signs.append(1.0)  # H + N = h
    elif src_vcrs == 'ellipsoidal' and tgt_vcrs in ('EGM96', 'EGM2008'):
        geoid_models.append(tgt_vcrs)
        signs.append(-1.0)  # h - N = H
    elif src_vcrs in ('EGM96', 'EGM2008') and tgt_vcrs in ('EGM96', 'EGM2008'):
        geoid_models.extend([src_vcrs, tgt_vcrs])
        signs.extend([1.0, -1.0])  # H1 + N1 - N2
    else:
        return data, nodata

    x_arr = np.asarray(x_coords, dtype=np.float64)
    y_arr = np.asarray(y_coords, dtype=np.float64)

    # Decide the output dtype once, here, so every backend agrees and
    # the caller can update attrs['nodata'] etc. with the post-shift
    # sentinel.
    promoted = _promoted_vertical_dtype(data.dtype)
    if promoted is None:
        out_dtype = np.dtype(data.dtype)
        out_nodata = nodata
    else:
        out_dtype = np.dtype(promoted)
        # NaN is the natural sentinel in a float output.  The numpy path
        # below replaces any cells that matched the input integer
        # ``nodata`` with NaN before applying the shift.
        out_nodata = float('nan')

    # Dask backend: wrap the per-block computation.  Each block sees its
    # own row slab of x/y coords, so we route through map_blocks and
    # delegate to the numpy path on every chunk.
    try:
        import dask.array as da
        is_dask = isinstance(data, da.Array)
    except ImportError:
        is_dask = False

    if is_dask:
        shifted = _apply_vertical_shift_dask(
            data, y_arr, x_arr, geoid_models, signs, nodata, tgt_crs_wkt,
            out_dtype=out_dtype,
        )
        return shifted, out_nodata

    # CuPy backend: round-trip via host.  The geoid lookup is small
    # relative to the reprojection itself, and the CPU JIT path is
    # already well-tested.
    try:
        import cupy as cp
        is_cupy = isinstance(data, cp.ndarray)
    except ImportError:
        is_cupy = False

    if is_cupy:
        host = cp.asnumpy(data)
        shifted = _apply_vertical_shift_numpy(
            host, y_arr, x_arr, geoid_models, signs, nodata, tgt_crs_wkt,
            out_dtype=out_dtype,
        )
        return cp.asarray(shifted), out_nodata

    shifted = _apply_vertical_shift_numpy(
        np.asarray(data), y_arr, x_arr, geoid_models, signs, nodata,
        tgt_crs_wkt, out_dtype=out_dtype,
    )
    return shifted, out_nodata


def _apply_vertical_shift_numpy(data, y_arr, x_arr, geoid_models, signs,
                                nodata, tgt_crs_wkt, out_dtype=None):
    """Apply geoid shift on a numpy array.

    Handles both 2-D ``(H, W)`` and 3-D ``(H, W, B)`` shapes by looping
    over band slices; the geoid undulation depends only on horizontal
    position, so each band sees the same N(y, x) correction.

    If ``out_dtype`` is given and differs from ``data.dtype`` (the
    integer-DEM case), the input is cast to that float dtype before the
    shift and the integer ``nodata`` sentinel is replaced with NaN in
    the promoted output.  When ``out_dtype`` is ``None`` (or matches the
    input), the shift is applied to a copy of the input in its own dtype.

    Parameters
    ----------
    data : numpy.ndarray
        Height values, 2-D ``(H, W)`` or 3-D ``(H, W, B)``.
    y_arr, x_arr : numpy.ndarray
        1-D output-grid coordinates for the rows / columns of ``data``.
    geoid_models : list of str
        Geoid model names handed to ``_load_geoid``.
    signs : list of float
        Sign applied to each model's undulation, paired with
        ``geoid_models``.
    nodata : scalar
        Sentinel marking missing cells in ``data``.
    tgt_crs_wkt : str or None
        WKT of the output CRS; when it is projected the coordinates are
        inverse-projected to EPSG:4326 before the geoid lookup.
    out_dtype : dtype-like or None
        Dtype of the returned array (see above).

    Returns
    -------
    numpy.ndarray
        A new array; ``data`` itself is never modified.

    Raises
    ------
    ValueError
        If ``data`` is neither 2-D nor 3-D.
    """
    # Validate with a real exception (asserts vanish under ``python -O``)
    # and do it before any geoid grid is loaded.
    if data.ndim not in (2, 3):
        raise ValueError(f"expected 2-D or 3-D, got {data.ndim}-D")

    from ._projections import _PARALLEL_KERNEL_LOCK
    from ._vertical import _interp_geoid_2d, _load_geoid

    # Determine if we need inverse projection (output CRS is projected)
    need_inverse = False
    transformer = None
    if tgt_crs_wkt is not None:
        # NOTE(review): deliberate best effort -- if the CRS cannot be
        # inspected the coordinates are treated as lon/lat.  For a
        # projected CRS that would silently sample the wrong geoid
        # cells; consider surfacing a warning here.
        try:
            from ._crs_utils import _require_pyproj
            pyproj = _require_pyproj()
            tgt_crs = pyproj.CRS.from_wkt(tgt_crs_wkt)
            if not tgt_crs.is_geographic:
                need_inverse = True
                geo_crs = pyproj.CRS.from_epsg(4326)
                transformer = pyproj.Transformer.from_crs(
                    tgt_crs, geo_crs, always_xy=True
                )
        except Exception:
            pass

    out_h, out_w = data.shape[:2]
    is_3d = data.ndim == 3
    n_bands = data.shape[2] if is_3d else 1

    geoids = [_load_geoid(gm) for gm in geoid_models]

    # If the caller asked for dtype promotion (integer DEM -> float),
    # build ``result`` in the new dtype and rewrite the input nodata
    # sentinel to NaN so the rest of the loop can treat NaN as the
    # missing-value marker uniformly.
    promoted = (
        out_dtype is not None
        and np.dtype(out_dtype) != np.dtype(data.dtype)
    )
    if promoted:
        result = data.astype(out_dtype, copy=True)
        # Map the source nodata (e.g. -32768 for int16) onto NaN before
        # shifting so downstream consumers can use NaN semantics.
        if not (isinstance(nodata, float) and np.isnan(nodata)):
            result[data == nodata] = np.nan
        is_nan_nodata = True
    else:
        result = data.copy()
        is_nan_nodata = np.isnan(nodata) if isinstance(nodata, float) else False

    # Process in row strips to bound memory in the numpy path; dask chunks
    # are usually smaller than one strip so this loop runs once per block.
    strip = 128
    for r0 in range(0, out_h, strip):
        r1 = min(r0 + strip, out_h)
        n_rows = r1 - r0

        # Build strip coordinate grid
        xx_strip = np.tile(x_arr, n_rows).reshape(n_rows, out_w)
        yy_strip = np.repeat(y_arr[r0:r1], out_w).reshape(n_rows, out_w)

        # Inverse project if needed
        if need_inverse and transformer is not None:
            lon_s, lat_s = transformer.transform(xx_strip.ravel(), yy_strip.ravel())
            xx_strip = np.asarray(lon_s, dtype=np.float64).reshape(n_rows, out_w)
            yy_strip = np.asarray(lat_s, dtype=np.float64).reshape(n_rows, out_w)

        # Guard against non-finite output coords (projection singularities,
        # antimeridian, polar regions).  Hand NaN to the JIT batch so the
        # longitude wrap loop in _interp_geoid_point does not see inf and
        # spin forever.
        finite_coord = np.isfinite(xx_strip) & np.isfinite(yy_strip)
        if not finite_coord.all():
            xx_strip = np.where(finite_coord, xx_strip, np.nan)
            yy_strip = np.where(finite_coord, yy_strip, np.nan)

        # Compute the total shift N_total for this strip once, regardless
        # of how many geoid models contribute.
        N_total = np.zeros((n_rows, out_w), dtype=np.float64)
        for (grid_data, g_left, g_top, g_rx, g_ry, g_h, g_w), sign in zip(geoids, signs):
            N_strip = np.empty((n_rows, out_w), dtype=np.float64)
            # This runs per-chunk under dask's threaded scheduler (via
            # _apply_vertical_shift_dask), and the kernel is parallel=True,
            # so the launch must be serialized (#3141).
            with _PARALLEL_KERNEL_LOCK:
                _interp_geoid_2d(xx_strip, yy_strip, N_strip,
                                 grid_data, g_left, g_top, g_rx, g_ry, g_h, g_w)
            N_total += sign * N_strip

        # Apply to each band slice.  For 2-D this loop runs once.  The
        # slices are basic-indexing views, so ``+=`` writes into ``result``.
        for b in range(n_bands):
            strip_data = result[r0:r1, :, b] if is_3d else result[r0:r1]
            if is_nan_nodata:
                is_valid = np.isfinite(strip_data)
            else:
                is_valid = strip_data != nodata
            is_valid = is_valid & finite_coord
            strip_data[is_valid] += N_total[is_valid]

    return result


def _apply_vertical_shift_dask(data, y_arr, x_arr, geoid_models, signs,
                               nodata, tgt_crs_wkt, out_dtype=None):
    """Dask-backed geoid shift via ``map_blocks``.

    Each block receives only its row/column slab of the input.  The
    block function recomputes per-block y/x slices from ``block_info``
    and delegates to the numpy path so all backends share one
    implementation.

    ``out_dtype`` is the (possibly promoted) dtype that
    ``_apply_vertical_shift`` has decided for the whole array, and is
    forwarded into every per-block call so the chunks return the same
    promoted dtype the dask graph metadata advertises.  When it is
    ``None`` the input dtype is used.

    Returns
    -------
    dask.array.Array
        Lazy array with the same chunking as ``data``.
    """
    import dask.array as da

    if out_dtype is None:
        out_dtype = data.dtype

    # Note: blocks reaching this path are assumed to be numpy-backed.
    # dask-of-cupy is intentionally unreached today because
    # ``_reproject_dask(is_cupy=True)`` collapses to a numpy-backed dask
    # array upstream.  If that ever changes, _block would need to detect
    # cupy chunks and host-bounce per chunk.
    def _block(block, block_info):
        info = block_info[0]
        # Only the first two axes are spatial; a trailing band axis (if
        # any) does not affect which coordinates the block covers.
        (r0, r1), (c0, c1) = info['array-location'][:2]
        y_slab = y_arr[r0:r1]
        x_slab = x_arr[c0:c1]
        return _apply_vertical_shift_numpy(
            block, y_slab, x_slab, geoid_models, signs, nodata, tgt_crs_wkt,
            out_dtype=out_dtype,
        )

    # ``meta`` hardcodes a numpy template to match the assumption above
    # that incoming chunks are numpy-backed.  Revisit if dask-of-cupy is
    # ever plumbed through.
    return da.map_blocks(
        _block, data, dtype=out_dtype,
        meta=np.array((), dtype=out_dtype),
        **_dask_task_name_kwargs('xrspatial.reproject_vertical_shift'),
    )


def _reproject_inmemory_numpy(
    raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt,
    out_bounds, out_shape, resampling, nodata, precision,
    x_desc=False, band_nodata=None,
):
    """Single-chunk numpy reproject.

    Hands ``raster.values`` and the full output grid to
    ``_reproject_chunk_numpy`` and returns whatever it produces.
    """
    return _reproject_chunk_numpy(
        raster.values, src_bounds, src_shape, y_desc,
        src_wkt, tgt_wkt, out_bounds, out_shape,
        resampling, nodata, precision,
        source_x_desc=x_desc, band_nodata=band_nodata,
    )


def _reproject_inmemory_cupy(
    raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt,
    out_bounds, out_shape, resampling, nodata, precision,
    x_desc=False, band_nodata=None,
):
    """Single-chunk cupy reproject.

    Hands ``raster.data`` and the full output grid to
    ``_reproject_chunk_cupy`` and returns whatever it produces.
    """
    return _reproject_chunk_cupy(
        raster.data, src_bounds, src_shape, y_desc,
        src_wkt, tgt_wkt, out_bounds, out_shape,
        resampling, nodata, precision,
        source_x_desc=x_desc, band_nodata=band_nodata,
    )


def _parse_max_memory(max_memory):
    """Parse the ``max_memory`` parameter to a byte count.

    Parameters
    ----------
    max_memory : None, int, float or str
        ``None`` selects the 1 GiB default.  Numbers are taken as bytes.
        Strings may be a plain integer (``'1048576'``) or a number with
        a binary-unit suffix -- ``TB``, ``GB``, ``MB``, ``KB`` or ``B``
        (case-insensitive, surrounding whitespace ignored), e.g.
        ``'4GB'``, ``'512MB'``, ``'1.5 gb'``.

    Returns
    -------
    int
        Memory budget in bytes.

    Raises
    ------
    ValueError
        If a string value cannot be interpreted as a size.
    """
    if max_memory is None:
        return 1024 * 1024 * 1024  # 1GB default
    if isinstance(max_memory, (int, float)):
        return int(max_memory)

    text = str(max_memory).strip().upper()
    # Longest suffixes first: every unit ends in 'B', so the bare-byte
    # suffix has to be tried last.
    units = (
        ('TB', 1024 ** 4),
        ('GB', 1024 ** 3),
        ('MB', 1024 ** 2),
        ('KB', 1024),
        ('B', 1),
    )
    try:
        for suffix, factor in units:
            if text.endswith(suffix):
                return int(float(text[:-len(suffix)]) * factor)
        return int(text)
    except ValueError:
        raise ValueError(
            "max_memory must be a byte count or a size string such as "
            f"'4GB' or '512MB', got {max_memory!r}"
        ) from None


def _process_tile_batch(batch, source_data, src_bounds, src_shape, y_desc,
                        src_wkt, tgt_wkt, resampling, nodata, precision,
                        max_memory_bytes, tile_mem,
                        x_desc=False, band_nodata=None):
    """Process a batch of tiles within a single worker.

    Uses ThreadPoolExecutor for intra-worker parallelism (Numba releases
    the GIL).  Memory bounded by max_memory_bytes: at most
    ``max_memory_bytes // tile_mem`` tiles are reprojected concurrently.

    The numba coordinate-transform fast path uses parallel=True kernels,
    which must not run concurrently from multiple host threads (the
    workqueue threading layer aborts the process, #3141).
    try_numba_transform serializes those launches behind
    _projections._PARALLEL_KERNEL_LOCK; resampling stays concurrent.

    Parameters
    ----------
    batch : sequence of tuple
        Jobs of the form ``(row_offset, col_offset, n_rows, n_cols,
        chunk_bounds)``.
    tile_mem : int
        Estimated working-set size of one tile, in bytes.

    The remaining arguments are forwarded to ``_reproject_chunk_numpy``.

    Returns
    -------
    list of (row_offset, col_offset, tile_data) tuples, in job order.
    """
    max_concurrent = max(1, max_memory_bytes // max(tile_mem, 1))

    def _do_one(job):
        _, _, rchunk, cchunk, cb = job
        return _reproject_chunk_numpy(
            source_data, src_bounds, src_shape, y_desc,
            src_wkt, tgt_wkt, cb, (rchunk, cchunk),
            resampling, nodata, precision,
            source_x_desc=x_desc, band_nodata=band_nodata,
        )

    results = []
    if max_concurrent >= 2 and len(batch) > 1:
        import os
        from concurrent.futures import ThreadPoolExecutor

        n_threads = min(max_concurrent, len(batch), os.cpu_count() or 4)
        with ThreadPoolExecutor(max_workers=n_threads) as pool:
            # Feed the pool one wave of ``n_threads`` jobs at a time so no
            # more than ``n_threads`` tiles are ever in flight.
            for sub_start in range(0, len(batch), n_threads):
                sub = batch[sub_start:sub_start + n_threads]
                for job, tile in zip(sub, pool.map(_do_one, sub)):
                    results.append((job[0], job[1], tile))
    else:
        for job in batch:
            results.append((job[0], job[1], _do_one(job)))
    return results


def _reproject_streaming(
    raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt,
    out_bounds, out_shape, resampling, nodata, precision,
    tile_size, max_memory_bytes, x_desc=False, band_nodata=None,
):
    """Streaming reproject for datasets too large for dask's graph.

    Two modes:

    1. **Local** (no dask.distributed): ThreadPoolExecutor within one
       process, bounded by max_memory.
    2. **Distributed** (dask.distributed active): creates a dask.bag with
       one partition per worker, each partition processes its tile batch
       using threads.  Graph size: O(n_workers), not O(n_tiles).

    Memory usage per worker: bounded by max_memory.

    Output dtype follows the same rule as the dask backends: integer
    sources round-trip back to their original dtype (the per-tile worker
    casts tiles back after clamping), floats return float64 (#3093).

    Returns
    -------
    numpy.ndarray
        The assembled output, ``out_shape`` for 2-D sources or
        ``(*out_shape, n_bands)`` for 3-D ``(y, x, band)`` sources.
    """
    if isinstance(tile_size, int):
        tile_size = (tile_size, tile_size)

    # Match the dask backends: integer sources round-trip back to their
    # original dtype after clamping (the per-tile worker already casts
    # tiles back); floats stay float64.  Without this, the streaming path
    # silently promoted integer inputs to float64 while every other
    # backend preserved the source dtype (#3093).
    src_dtype = np.dtype(raster.dtype)
    if np.issubdtype(src_dtype, np.integer):
        out_dtype = src_dtype
    else:
        out_dtype = np.dtype(np.float64)

    # 3-D (y, x, band) sources produce 3-D tiles, so the assembled output
    # needs the band axis too (#3093).
    if raster.data.ndim == 3:
        n_bands = raster.data.shape[2]
        result_shape = (*out_shape, n_bands)
    else:
        n_bands = 1
        result_shape = out_shape

    row_chunks, col_chunks = _compute_chunk_layout(out_shape, tile_size)

    # Working-set estimate per tile: ~3 float64 coordinate/scratch planes
    # plus one output plane per band.  For 2-D sources this is the same
    # "~4 arrays per tile" budget as before; multi-band tiles are larger
    # and must count against max_memory accordingly.
    tile_mem = tile_size[0] * tile_size[1] * 8 * (3 + n_bands)

    # Build tile job list
    jobs = []
    row_offset = 0
    for rchunk in row_chunks:
        col_offset = 0
        for cchunk in col_chunks:
            cb = _chunk_bounds(
                out_bounds, out_shape,
                row_offset, row_offset + rchunk,
                col_offset, col_offset + cchunk,
            )
            jobs.append((row_offset, col_offset, rchunk, cchunk, cb))
            col_offset += cchunk
        row_offset += rchunk

    # Check if dask.distributed is active.  get_client() raises
    # ValueError when no client has been created.
    n_distributed_workers = 0
    try:
        from dask.distributed import get_client
        client = get_client()
        n_distributed_workers = len(client.scheduler_info()['workers'])
    except (ImportError, ValueError):
        pass

    # Allocate the output before any tile is computed so an impossible
    # allocation fails up front rather than after the work is done.
    result = np.full(result_shape, nodata, dtype=out_dtype)

    if n_distributed_workers > 0 and len(jobs) > n_distributed_workers:
        # Distributed: partition tiles across workers via dask.bag
        import dask.bag as db

        # Split jobs into N partitions (one per worker)
        n_parts = min(n_distributed_workers, len(jobs))
        batch_size = math.ceil(len(jobs) / n_parts)
        batches = [jobs[i:i + batch_size]
                   for i in range(0, len(jobs), batch_size)]

        # Create bag and map the batch processor
        bag = db.from_sequence(batches, npartitions=len(batches))
        results_bag = bag.map(
            _process_tile_batch,
            source_data=raster.data,
            src_bounds=src_bounds, src_shape=src_shape, y_desc=y_desc,
            src_wkt=src_wkt, tgt_wkt=tgt_wkt,
            resampling=resampling, nodata=nodata, precision=precision,
            max_memory_bytes=max_memory_bytes, tile_mem=tile_mem,
            x_desc=x_desc, band_nodata=band_nodata,
        )
        tile_batches = results_bag.compute()
    else:
        # Local: ThreadPoolExecutor within one process
        tile_batches = [_process_tile_batch(
            jobs, raster.data, src_bounds, src_shape, y_desc,
            src_wkt, tgt_wkt, resampling, nodata, precision,
            max_memory_bytes, tile_mem,
            x_desc=x_desc, band_nodata=band_nodata,
        )]

    # Paste every tile into place; a trailing band axis (if any) is
    # covered by the implicit full slice.
    for batch_results in tile_batches:
        for ro, co, tile in batch_results:
            result[ro:ro + tile.shape[0], co:co + tile.shape[1]] = tile
    return result


def _reproject_dask_cupy(
    raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt,
    out_bounds, out_shape, resampling, nodata, precision,
    chunk_size, x_desc=False, band_nodata=None,
):
    """Dask+CuPy backend: process output chunks on GPU.

    Two modes based on available GPU memory:

    **Fast path** (output fits in GPU VRAM): pre-allocates the full output
    on GPU and fills it chunk-by-chunk.  ~22x faster than the map_blocks
    path because CRS/transformer objects are created once and CUDA kernels
    run with minimal launch overhead.

    **Chunked fallback** (output exceeds GPU VRAM): delegates to
    ``_reproject_dask(is_cupy=True)`` which uses ``map_blocks`` and
    processes one chunk at a time with O(chunk_size) GPU memory.

    3-D ``(y, x, band)`` sources always take the chunked fallback.

    Returns
    -------
    cupy.ndarray or dask.array.Array
        An eager cupy array from the fast path, or the lazy result of
        ``_reproject_dask`` from the fallback.
    """
    import cupy as cp

    from ._crs_utils import _crs_from_wkt

    src_crs = _crs_from_wkt(src_wkt)
    tgt_crs = _crs_from_wkt(tgt_wkt)

    # Use larger chunks for GPU to amortize kernel launch overhead
    gpu_chunk = chunk_size or 2048
    if isinstance(gpu_chunk, int):
        gpu_chunk = (gpu_chunk, gpu_chunk)

    row_chunks, col_chunks = _compute_chunk_layout(out_shape, gpu_chunk)

    src_left, src_bottom, src_right, src_top = src_bounds
    src_h, src_w = src_shape
    src_res_x = (src_right - src_left) / src_w
    src_res_y = (src_top - src_bottom) / src_h

    def _delegate_to_map_blocks():
        return _reproject_dask(
            raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt,
            out_bounds, out_shape, resampling, nodata, precision,
            chunk_size or 2048, True,  # is_cupy=True
            x_desc=x_desc, band_nodata=band_nodata,
        )

    # 3-D source: the fast path's inline loop assumes 2-D windows.
    # Delegate to the map_blocks path which handles 3-D via
    # _reproject_chunk_cupy's per-band loop (#2027).
    if raster.data.ndim == 3:
        return _delegate_to_map_blocks()

    # Memory check: if the full output doesn't fit in GPU memory,
    # fall back to the map_blocks path which is O(chunk_size) memory.
    estimated = out_shape[0] * out_shape[1] * 8  # float64
    try:
        free_gpu, _total = cp.cuda.Device().mem_info
        fits_in_gpu = estimated < 0.5 * free_gpu
    except (AttributeError, RuntimeError):
        fits_in_gpu = False

    if not fits_in_gpu:
        import warnings
        warnings.warn(
            f"Output ({estimated / 1e9:.1f} GB) exceeds GPU memory; "
            f"falling back to chunked map_blocks path.",
            stacklevel=3,
        )
        return _delegate_to_map_blocks()

    # Match the dask+numpy and chunked dask+cupy paths: integer sources
    # round-trip back to their original dtype after clamping; floats stay
    # float64.  Without this, the eager dask+cupy fast path silently
    # promoted int16/uint8/etc. inputs to float64 while the other
    # backends preserved the source dtype (#2505).
    src_dtype = np.dtype(raster.dtype)
    if np.issubdtype(src_dtype, np.integer):
        out_dtype = src_dtype
    else:
        out_dtype = np.dtype(np.float64)

    # Guard: cap source window to prevent GPU OOM (matches numpy path)
    _MAX_WINDOW_PIXELS = 64 * 1024 * 1024  # 64 Mpix

    cpu_transformer = None

    def _pyproj_transformer():
        # Built on first use and shared by every chunk that needs the
        # CPU fallback; the CRS pair is the same for all of them.
        nonlocal cpu_transformer
        if cpu_transformer is None:
            from ._crs_utils import _require_pyproj
            pyproj = _require_pyproj()
            cpu_transformer = pyproj.Transformer.from_crs(
                tgt_crs, src_crs, always_xy=True
            )
        return cpu_transformer

    def _source_pixels(src_y, src_x):
        # Map source-CRS coordinates to fractional (row, col) pixel
        # indices, honouring the direction of each source axis.
        if x_desc:
            col_px = (src_right - src_x) / src_res_x - 0.5
        else:
            col_px = (src_x - src_left) / src_res_x - 0.5
        if y_desc:
            row_px = (src_top - src_y) / src_res_y - 0.5
        else:
            row_px = (src_y - src_bottom) / src_res_y - 0.5
        return row_px, col_px

    def _reproject_tile(cb, chunk_shape):
        # Return the resampled chunk, or None when the chunk has nothing
        # to sample (no finite coords, no overlap, oversized window) and
        # should keep its nodata fill.

        # CUDA coordinate transform (reuses cached CRS objects).
        # precision == 0 forces the exact pyproj path, so skip CUDA.
        cuda_coords = None
        if precision != 0:
            try:
                from ._projections_cuda import try_cuda_transform
                cuda_coords = try_cuda_transform(
                    src_crs, tgt_crs, cb, chunk_shape,
                )
            except ImportError:
                cuda_coords = None

        if cuda_coords is not None:
            src_y, src_x = cuda_coords
            src_row_px, src_col_px = _source_pixels(src_y, src_x)
            # Batch the four reductions into a single device-to-host
            # transfer instead of four separate synchronous .get() calls.
            extent = cp.stack([
                cp.nanmin(src_row_px), cp.nanmax(src_row_px),
                cp.nanmin(src_col_px), cp.nanmax(src_col_px),
            ]).get()
        else:
            # CPU fallback for this chunk
            src_y, src_x = _transform_coords(
                _pyproj_transformer(), cb, chunk_shape, precision,
                src_crs=src_crs, tgt_crs=tgt_crs,
            )
            src_row_px, src_col_px = _source_pixels(src_y, src_x)
            extent = (
                np.nanmin(src_row_px), np.nanmax(src_row_px),
                np.nanmin(src_col_px), np.nanmax(src_col_px),
            )

        r_lo, r_hi, c_lo, c_hi = (float(v) for v in extent)
        if not all(math.isfinite(v) for v in (r_lo, r_hi, c_lo, c_hi)):
            return None

        # Pad the window so the resampling kernels have their stencil.
        r_min = int(np.floor(r_lo)) - 2
        r_max = int(np.ceil(r_hi)) + 3
        c_min = int(np.floor(c_lo)) - 2
        c_max = int(np.ceil(c_hi)) + 3

        # Check overlap
        if r_min >= src_h or r_max <= 0 or c_min >= src_w or c_max <= 0:
            return None

        r_min_clip = max(0, r_min)
        r_max_clip = min(src_h, r_max)
        c_min_clip = max(0, c_min)
        c_max_clip = min(src_w, c_max)

        win_pixels = (r_max_clip - r_min_clip) * (c_max_clip - c_min_clip)
        if win_pixels > _MAX_WINDOW_PIXELS:
            return None

        # Fetch only the needed source window from dask
        window = raster.data[r_min_clip:r_max_clip, c_min_clip:c_max_clip]
        if hasattr(window, 'compute'):
            window = window.compute()
        if not isinstance(window, cp.ndarray):
            window = cp.asarray(window)
        window = window.astype(cp.float64)
        if not np.isnan(nodata):
            window[window == nodata] = cp.nan

        local_row = src_row_px - r_min_clip
        local_col = src_col_px - c_min_clip

        # Always use the native CUDA kernels (numpy parity).  The
        # cupyx.scipy.ndimage.map_coordinates fallback diverged from
        # numpy at chunk edges and for cubic resampling (#2620).
        # local_row/local_col may be numpy here (pyproj fallback);
        # _resample_cupy_native transfers them to the GPU.
        chunk_data = _resample_cupy_native(
            window, local_row, local_col,
            resampling=resampling, nodata=nodata,
        )

        # Clamp + cast back for integer source dtypes so this fast
        # path returns the same dtype as the other backends (#2505).
        # Matches the per-chunk cast in _reproject_chunk_cupy.
        if np.issubdtype(out_dtype, np.integer):
            info = np.iinfo(out_dtype)
            chunk_data = cp.clip(
                cp.round(chunk_data), info.min, info.max,
            ).astype(out_dtype)
        return chunk_data

    result = cp.full(out_shape, nodata, dtype=out_dtype)

    row_offset = 0
    for rchunk in row_chunks:
        col_offset = 0
        for cchunk in col_chunks:
            cb = _chunk_bounds(
                out_bounds, out_shape,
                row_offset, row_offset + rchunk,
                col_offset, col_offset + cchunk,
            )
            chunk_data = _reproject_tile(cb, (rchunk, cchunk))
            if chunk_data is not None:
                result[row_offset:row_offset + rchunk,
                       col_offset:col_offset + cchunk] = chunk_data
            col_offset += cchunk
        row_offset += rchunk

    return result


def _finite_pair_bbox(tx, ty):
    """Bounding box of (tx, ty) pairs where both coordinates are finite.

    The x and y coordinates must be filtered together: a transform can
    send some probe points to NaN/inf, and dropping finite x and finite y
    independently would mix coordinates from different points into one
    box.

    Returns ``(left, bottom, right, top)`` or ``None`` when no pair is
    finite in both coordinates.
    """
    tx = np.asarray(tx, dtype=np.float64)
    ty = np.asarray(ty, dtype=np.float64)
    mask = np.isfinite(tx) & np.isfinite(ty)
    if not mask.any():
        return None
    tx = tx[mask]
    ty = ty[mask]
    return (float(tx.min()), float(ty.min()),
            float(tx.max()), float(ty.max()))


def _source_footprint_in_target(src_bounds, src_wkt, tgt_wkt):
    """Compute approximate bounding box of source raster in target CRS.

    A regular 5 x 5 grid of probe points spanning ``src_bounds`` (corners,
    edge midpoints, centre and the quarter points between them) is
    transformed to the target CRS and the bounding box of the finite
    results is returned.

    Returns
    -------
    (left, bottom, right, top) or None
        ``None`` when either CRS cannot be resolved, the transform
        fails, or no probe point maps to finite coordinates.  Callers
        treat ``None`` as "footprint unknown".
    """
    try:
        from ._crs_utils import _crs_from_wkt, _resolve_crs
        try:
            src_crs = _crs_from_wkt(src_wkt)
        except Exception:
            src_crs = _resolve_crs(src_wkt)
        try:
            tgt_crs = _crs_from_wkt(tgt_wkt)
        except Exception:
            tgt_crs = _resolve_crs(tgt_wkt)
    except Exception:
        return None

    sl, sb, sr, st = src_bounds
    # Every probe must be a genuine (x, y) pair inside the source extent.
    # The previous hand-written lists put the *x* midpoint into the y
    # list, producing points far outside the raster that inflated the
    # footprint and defeated empty-chunk skipping.  The grid is a
    # superset of the intended corner/midpoint samples, so the box can
    # only be as tight as, or more conservative than, the 3 x 3 one.
    n_probe = 5
    xs = np.tile(np.linspace(sl, sr, n_probe), n_probe)
    ys = np.repeat(np.linspace(sb, st, n_probe), n_probe)

    try:
        from ._projections import transform_points
        result = transform_points(src_crs, tgt_crs, xs, ys)
        if result is not None:
            tx, ty = result
            return _finite_pair_bbox(tx, ty)
    except ImportError:
        pass

    # Fast path unavailable or declined this CRS pair: fall back to pyproj.
    try:
        from ._crs_utils import _require_pyproj
        pyproj = _require_pyproj()
        transformer = pyproj.Transformer.from_crs(src_crs, tgt_crs, always_xy=True)
        tx, ty = transformer.transform(xs.tolist(), ys.tolist())
        return _finite_pair_bbox(tx, ty)
    except Exception:
        return None


def _bounds_overlap(a, b):
    """Return True if bounding boxes *a* and *b* overlap.

    Both boxes are ``(left, bottom, right, top)``.  Boxes that merely
    touch along an edge or at a corner do not count as overlapping.
    """
    return a[0] < b[2] and a[2] > b[0] and a[1] < b[3] and a[3] > b[1]


def _reproject_block_adapter(
    block, block_info,
    source_data, src_bounds, src_shape, y_desc,
    src_wkt, tgt_wkt, out_bounds, out_shape,
    resampling, nodata, precision,
    is_cupy, src_footprint_tgt,
    n_bands=None, x_desc=False, band_nodata=None,
):
    """``map_blocks`` adapter for reprojection.

    Derives chunk bounds from *block_info* and delegates to the
    per-chunk worker.  The contents of *block* (a slice of the template
    array) are never read; only its location matters.

    For 3-D sources the template carries the band axis, so each block
    is ``(rh, rw, n_bands)``.  The adapter strips the trailing band axis
    when computing 2-D chunk bounds and the per-chunk worker returns a
    3-D result that fits the template (#2027).

    Returns
    -------
    numpy.ndarray or cupy.ndarray
        The reprojected chunk, or a nodata-filled chunk when the chunk
        lies entirely outside *src_footprint_tgt*.
    """
    info = block_info[0]
    # 3-D template -> array-location is 3 entries; spatial dims are the
    # first two.  Band dim spans the full axis (single chunk).
    spatial_loc = info['array-location'][:2]
    (row_start, row_end), (col_start, col_end) = spatial_loc
    chunk_shape = (row_end - row_start, col_end - col_start)
    cb = _chunk_bounds(out_bounds, out_shape,
                       row_start, row_end, col_start, col_end)

    is_3d = n_bands is not None

    # Skip chunks that don't overlap the source footprint.  The fill dtype
    # must match the template: integer sources advertise their own dtype,
    # so a float64 fill here would promote the assembled output (#3096).
    if src_footprint_tgt is not None and not _bounds_overlap(cb, src_footprint_tgt):
        if np.issubdtype(source_data.dtype, np.integer):
            empty_dtype = source_data.dtype
        else:
            empty_dtype = np.float64
        empty_shape = (*chunk_shape, n_bands) if is_3d else chunk_shape
        if is_cupy:
            # Match the data chunks (and the 3-D branch): dask+cupy
            # blocks should be cupy arrays even when empty.
            import cupy as cp
            return cp.full(empty_shape, nodata, dtype=empty_dtype)
        return np.full(empty_shape, nodata, dtype=empty_dtype)

    chunk_fn = _reproject_chunk_cupy if is_cupy else _reproject_chunk_numpy
    return chunk_fn(
        source_data, src_bounds, src_shape, y_desc,
        src_wkt, tgt_wkt, cb, chunk_shape,
        resampling, nodata, precision,
        source_x_desc=x_desc, band_nodata=band_nodata,
    )


def _reproject_dask(
    raster, src_bounds, src_shape, y_desc, src_wkt, tgt_wkt,
    out_bounds, out_shape, resampling, nodata, precision,
    chunk_size, is_cupy, x_desc=False, band_nodata=None,
):
    """Dask+NumPy backend: ``map_blocks`` over a template array.

    Uses a single ``blockwise`` layer in the HighLevelGraph instead of
    O(N) ``dask.delayed`` nodes, keeping graph metadata O(1).

    The source dask array is bound to the adapter via ``functools.partial``
    rather than passed as a ``map_blocks`` kwarg.  This prevents dask from
    adding the full source as a dependency of every output block (which
    would cause a MemoryError on distributed schedulers when the source
    exceeds the worker memory limit).

    For 3-D sources with shape ``(H, W, n_bands)`` the template is built
    as ``(out_H, out_W, n_bands)`` so the lazy metadata matches the
    actual chunk output shape (#2027).  Without this, the lazy DataArray
    advertised 2-D shape while the underlying chunks produced 3-D arrays,
    causing a ``ValueError: replacement data has shape ...`` crash on
    ``.compute()``.

    With ``is_cupy=True`` the per-chunk work is done by
    ``_reproject_chunk_cupy`` instead of ``_reproject_chunk_numpy``.

    Returns
    -------
    dask.array.Array
        Lazy output chunked as laid out by ``_compute_chunk_layout``.
    """
    import functools

    import dask.array as da

    row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size)

    # Precompute source footprint in target CRS for empty-chunk skipping
    src_footprint_tgt = _source_footprint_in_target(
        src_bounds, src_wkt, tgt_wkt
    )

    # Detect 3-D source: chunks will return (rh, rw, n_bands) so the
    # template must carry the band axis through to the lazy DataArray.
    is_3d = raster.data.ndim == 3
    n_bands = raster.data.shape[2] if is_3d else None

    # Bind the source dask array and all scalar params via partial so
    # map_blocks doesn't detect them as dask Array kwargs (which would
    # add the full source as a dependency of every output block).
    bound_adapter = functools.partial(
        _reproject_block_adapter,
        source_data=raster.data,
        src_bounds=src_bounds,
        src_shape=src_shape,
        y_desc=y_desc,
        src_wkt=src_wkt,
        tgt_wkt=tgt_wkt,
        out_bounds=out_bounds,
        out_shape=out_shape,
        resampling=resampling,
        nodata=nodata,
        precision=precision,
        is_cupy=is_cupy,
        src_footprint_tgt=src_footprint_tgt,
        n_bands=n_bands,
        x_desc=x_desc,
        band_nodata=band_nodata,
    )

    # Pick the template dtype to match the eager path: integer sources
    # round-trip back to their original dtype after clamping; floats stay
    # float64.  Without this, dask claims float64 meta but the chunks
    # actually return the integer dtype, producing inconsistent output.
    src_dtype = np.dtype(raster.dtype)
    if np.issubdtype(src_dtype, np.integer):
        out_dtype = src_dtype
    else:
        out_dtype = np.dtype(np.float64)

    if is_3d:
        # Band axis is one chunk in the template regardless of how the
        # source dask array is chunked along its band axis.  The per-block
        # worker reads all bands together (via a 2-D y/x slice that
        # rejoins band chunks on compute) and emits the full band stack
        # for its (rh, rw) tile, so multi-chunk output along bands would
        # never get filled.
        template = da.empty(
            (*out_shape, n_bands), dtype=out_dtype,
            chunks=(row_chunks, col_chunks, (n_bands,)),
        )
    else:
        template = da.empty(
            out_shape, dtype=out_dtype, chunks=(row_chunks, col_chunks)
        )

    return da.map_blocks(
        bound_adapter,
        template,
        dtype=out_dtype,
        meta=np.array((), dtype=out_dtype),
        **_dask_task_name_kwargs('xrspatial.reproject'),
    )


# ---------------------------------------------------------------------------
# merge()
# ---------------------------------------------------------------------------

def _merge_inputs_to_host(rasters):
    """Bring cupy-backed merge inputs to the host.

    The merge pipeline is numpy-based end to end (``_place_same_crs``,
    ``_reproject_chunk_numpy``, ``_merge_arrays_numpy``), so cupy and
    dask+cupy inputs are converted before any pixel data is touched.

    Returns ``(any_cupy, host_rasters)``; the input list is returned
    unchanged when no raster is GPU-backed (or cupy is not installed).
    Dask-of-cupy inputs stay lazy: each chunk is converted via
    ``map_blocks`` at compute time.
    """
    try:
        import cupy as cp
    except ImportError:
        return False, rasters

    from ..utils import has_dask_array, is_dask_cupy

    # is_dask_cupy dereferences dask.array.Array, so it must only run
    # when dask is importable (cupy without dask is a valid install).
    _dask_ok = has_dask_array()

    any_cupy = False
    host = []
    for r in rasters:
        if isinstance(r.data, cp.ndarray):
            any_cupy = True
            host.append(r.copy(data=cp.asnumpy(r.data)))
        elif _dask_ok and is_dask_cupy(r):
            any_cupy = True
            host.append(r.copy(data=r.data.map_blocks(
                cp.asnumpy, dtype=r.dtype,
                meta=np.array((), dtype=r.dtype),
            )))
        else:
            host.append(r)
    return any_cupy, host
def merge(
    rasters,
    *,
    target_crs=None,
    resolution=None,
    bounds=None,
    resampling='bilinear',
    nodata=None,
    strategy='first',
    chunk_size=None,
    transform_precision=16,
    name=None,
    bounds_policy="auto",
):
    """Merge multiple rasters into a single mosaic.

    Each input is reprojected to the target CRS (if needed) and placed
    into a unified output grid.  Overlapping regions are resolved using
    the selected *strategy*.

    Accepts numpy, cupy, dask+numpy, and dask+cupy inputs. The merge
    itself runs on the CPU: cupy-backed rasters are copied to the host
    on entry and the mosaic is returned as a cupy (or dask+cupy) array
    when any input was GPU-backed.

    Parameters
    ----------
    rasters : list of xr.DataArray
        Input rasters to merge.  Each must be 2-D.
    target_crs : optional
        Target CRS.  Defaults to the CRS of the first raster.
    resolution : float or (float, float) or None
        Output resolution in target CRS units.
    bounds : (left, bottom, right, top) or None
        Explicit output extent.
    resampling : str
        Interpolation method: 'nearest', 'bilinear', 'cubic'.
    nodata : float or None
        Nodata value for the output.  Auto-detected from the first
        raster if None.  When every input shares one integer dtype,
        an explicit value that does not fit that dtype range raises
        ``ValueError``, and an auto-detected NaN falls back to
        ``dtype.min`` for signed or ``dtype.max`` for unsigned -- the
        same rules ``reproject`` applies (#2185/#2572).
    strategy : str
        Merge strategy: 'first', 'last', 'mean', 'max', 'min'.
    chunk_size : int or (int, int) or None
        Output chunk size for dask. If None, defaults to 512 for the
        standard dask path and 2048 for the in-memory streaming and
        dask+cupy paths (chosen to amortize kernel launch overhead).
    transform_precision : int
        Control-grid subdivisions for the coordinate transform (default 16).
        Higher values increase accuracy at the cost of more pyproj calls.
        Set to 0 for exact per-pixel transforms matching GDAL/rasterio.
    name : str or None
        Name for the output DataArray.  Falls back to the first raster's
        name, then to ``'merged'``.
    bounds_policy : {"auto", "raw", "clamp", "percentile"}, default "auto"
        How to derive the unified output extent from the input rasters
        when ``bounds`` is not supplied. See :func:`reproject` for the
        full description of each option. Ignored when ``bounds`` is
        given.

    Returns
    -------
    xr.DataArray
        When every input shares the same integer dtype, the mosaic is
        cast back to that dtype (values rounded and clipped, matching
        ``reproject``'s integer round-trip).  Mixed input dtypes and
        float inputs produce a float64 mosaic.

        The output y coordinate is always emitted in descending order
        (top-down, north-up) regardless of the input direction. This
        matches the standard raster convention and the output of common
        GIS libraries.

        Non-spatial coords from the first raster (such as a scalar
        ``time`` coord) are carried through to the output. Coords aligned
        to the spatial dims are dropped because their values do not
        apply to the merged grid.

    Raises
    ------
    ValueError
        If *rasters* is empty, if the target CRS cannot be determined,
        if any input raster has no detectable CRS, or if the final path
        is the in-memory merge and the output grid exceeds the pixel
        limit.  The argument validators called on entry may also raise.

    Notes
    -----
    There is no streaming in-memory path; for very large output
    mosaics, pass dask-backed inputs (or rely on the automatic
    promotion to the dask path) so that each output chunk is
    computed independently.  Automatic promotion only happens when
    dask is importable; without dask the merge stays in memory and is
    subject to the output pixel limit.

    Examples
    --------
    >>> import xarray as xr
    >>> import numpy as np
    >>> from xrspatial.reproject import merge
    >>> tile_a = xr.DataArray(
    ...     np.full((32, 32), 1.0), dims=['y', 'x'],
    ...     coords={'y': np.linspace(50, 45, 32),
    ...             'x': np.linspace(-5, 0, 32)},
    ...     attrs={'crs': 'EPSG:4326'},
    ... )
    >>> tile_b = xr.DataArray(
    ...     np.full((32, 32), 2.0), dims=['y', 'x'],
    ...     coords={'y': np.linspace(50, 45, 32),
    ...             'x': np.linspace(0, 5, 32)},
    ...     attrs={'crs': 'EPSG:4326'},
    ... )
    >>> mosaic = merge([tile_a, tile_b], resolution=0.5)
    >>> mosaic.shape[1] >= 32
    True
    """
    if not rasters:
        raise ValueError("merge(): rasters list must not be empty")

    for i, r in enumerate(rasters):
        # merge() only supports 2-D rasters: the merge strategies, same-CRS
        # placement, and output DataArray construction all assume (y, x).
        # The 3-D (y, x, band) path was never implemented end-to-end and
        # crashed at DataArray construction with a dims-vs-shape mismatch
        # (#2027). Reject 3-D up front so callers get a clear error.
        _validate_raster(r, func_name='merge', name=f'rasters[{i}]',
                         ndim=(2,))
        # Each input must have strictly monotonic, regularly spaced spatial
        # coords. _source_bounds() infers pixel size from only the first
        # two samples, so irregular inputs would otherwise silently produce
        # wrong georeferencing (#2184).
        _validate_source_coords(r, f'merge[rasters[{i}]]')

    _validate_grid_params(
        resolution=resolution,
        bounds=bounds,
        width=None,
        height=None,
        transform_precision=transform_precision,
        func_name='merge',
    )
    _validate_resampling(resampling)
    _validate_strategy(strategy)
    from ._grid import _validate_bounds_policy
    _validate_bounds_policy(bounds_policy, func_name='merge')

    # GPU inputs: the merge pipeline is numpy-based, so cupy and
    # dask+cupy rasters are merged on the host and the mosaic is moved
    # back to the GPU at the end. Same round-trip _apply_vertical_shift
    # uses for its CPU-only geoid lookup. Without this, cupy inputs
    # crashed at `.values` with xarray's implicit-conversion TypeError
    # (#3095).
    _any_cupy, rasters = _merge_inputs_to_host(rasters)

    # Resolve target CRS
    tgt_crs = _resolve_crs(target_crs)
    if tgt_crs is None:
        tgt_crs = _detect_source_crs(rasters[0])
    if tgt_crs is None:
        raise ValueError(
            "Could not detect target CRS. Pass target_crs explicitly."
        )

    # Output dtype follows the reproject() convention (#3262): when every
    # input shares one integer dtype, the mosaic is cast back to that
    # dtype after the float64 merge; mixed dtypes and float inputs return
    # float64. Resolved before nodata detection so the sentinel gets the
    # same dtype-aware treatment as reproject() (#2185/#2572): NaN is
    # swapped for a representable sentinel and an explicit out-of-range
    # nodata raises.
    in_dtypes = {np.dtype(r.dtype) for r in rasters}
    if len(in_dtypes) == 1 and np.issubdtype(next(iter(in_dtypes)), np.integer):
        out_dtype = next(iter(in_dtypes))
    else:
        out_dtype = np.dtype(np.float64)

    # Detect output nodata (the sentinel the user asked for). The merge
    # runs in float64 with NaN as the canonical sentinel, so NaN is fine
    # when the user didn't supply one and the output dtype is float; for
    # integer output the dtype hint resolves a representable sentinel.
    nd = _detect_nodata(rasters[0], nodata, dtype=out_dtype)

    # Gather source info for each raster
    raster_infos = []
    for r in rasters:
        src_crs = _detect_source_crs(r)
        if src_crs is None:
            raise ValueError(
                f"Could not detect CRS for raster '{r.name}'. "
                "Ensure all rasters have CRS metadata."
            )
        sb = _source_bounds(r)
        r_ydim, r_xdim = _find_spatial_dims(r)
        ss = (r.sizes[r_ydim], r.sizes[r_xdim])
        yd = _is_y_descending(r)
        xd = _is_x_descending(r)
        # Per-raster input nodata sentinel. Detected independently of the
        # user-supplied output nodata so that mixed-sentinel inputs are
        # canonicalized correctly during merge. Pass the raster dtype so
        # integer sources get an integer-compatible sentinel rather than
        # NaN -- the same fix as the reproject() path for #2185.
        r_nd = _detect_nodata(r, None, dtype=r.dtype)
        raster_infos.append({
            'raster': r,
            'src_crs': src_crs,
            'src_bounds': sb,
            'src_shape': ss,
            'y_desc': yd,
            'x_desc': xd,
            'src_wkt': src_crs.to_wkt(),
            'raster_nodata': r_nd,
        })

    # Compute unified output grid
    if bounds is None:
        # Union of all raster bounds in target CRS.
        # Capture bounds_policy warnings during the per-input gather and
        # emit a single deduplicated message afterwards. Otherwise a
        # mosaic of N near-antimeridian rasters yields N identical
        # warnings, which is noise the caller cannot act on individually.
        import warnings as _warnings

        def _is_policy_warning(w):
            # One predicate decides both "fold into the summary" and
            # "re-emit as-is", so every captured warning lands in exactly
            # one of the two buckets and none is silently dropped.
            return (issubclass(w.category, UserWarning)
                    and 'bounds_policy' in str(w.message))

        all_bounds = []
        with _warnings.catch_warnings(record=True) as _caught:
            _warnings.simplefilter('always', UserWarning)
            for info in raster_infos:
                grid = _compute_output_grid(
                    info['src_bounds'], info['src_shape'],
                    info['src_crs'], tgt_crs,
                    resolution=resolution,
                    bounds_policy=bounds_policy,
                )
                all_bounds.append(grid['bounds'])
        _policy_msgs = [
            str(w.message) for w in _caught if _is_policy_warning(w)
        ]
        if _policy_msgs:
            # dict.fromkeys de-duplicates while keeping first-seen order.
            _unique = list(dict.fromkeys(_policy_msgs))
            _summary = (
                f"merge: bounds_policy={bounds_policy!r} altered the "
                f"projected extent for {len(_policy_msgs)} input "
                f"raster(s); {len(_unique)} unique trigger(s). "
                f"First trigger: {_unique[0]}"
            )
            _warnings.warn(_summary, UserWarning, stacklevel=2)
        # Re-emit every captured warning that was not folded into the
        # summary above, with its original category and location.
        for _w in _caught:
            if not _is_policy_warning(_w):
                _warnings.warn_explicit(
                    _w.message, _w.category, _w.filename, _w.lineno,
                )
        left = min(b[0] for b in all_bounds)
        bottom = min(b[1] for b in all_bounds)
        right = max(b[2] for b in all_bounds)
        top = max(b[3] for b in all_bounds)
        merged_bounds = (left, bottom, right, top)
    else:
        merged_bounds = bounds

    # Detect dask inputs up front. A dask-backed merge runs through the
    # lazy _merge_dask path and never materializes the full output, so its
    # output-size guard must be skipped. The auto-promote-to-dask decision
    # below also needs the output shape, which only exists after the grid
    # is computed -- so compute the grid with the guard disabled here and
    # re-apply it only when the final path is the in-memory merge.
    from ..utils import has_dask_array
    dask_available = has_dask_array()
    any_dask = False
    if dask_available:
        import dask.array as _da
        any_dask = any(isinstance(r.data, _da.Array) for r in rasters)

    # Use first raster's info for resolution estimation if needed
    info0 = raster_infos[0]
    grid = _compute_output_grid(
        info0['src_bounds'], info0['src_shape'],
        info0['src_crs'], tgt_crs,
        resolution=resolution,
        bounds=merged_bounds,
        bounds_policy=bounds_policy,
        lazy_output=True,
    )
    out_bounds = grid['bounds']
    out_shape = grid['shape']
    tgt_wkt = tgt_crs.to_wkt()

    # Auto-promote to dask path if output would be too large for in-memory
    # merge. Only possible when dask is importable: _merge_dask imports
    # dask.array, so promoting without it would turn a feasible in-memory
    # merge into an import failure.
    if not any_dask and dask_available:
        out_nbytes = out_shape[0] * out_shape[1] * 8 * len(rasters)  # float64 per tile
        if out_nbytes > _MERGE_OOM_THRESHOLD:
            any_dask = True

    if not any_dask:
        # The final path is the in-memory merge, which allocates the whole
        # output array. Re-apply the output-size guard that was skipped
        # during grid computation so a genuinely in-memory merge over the
        # pixel limit still raises (the guard skip only applies to the
        # lazy dask path). _MAX_OUTPUT_PIXELS is imported from _grid so the
        # limit stays in sync with _compute_output_grid's own check.
        if out_shape[0] * out_shape[1] > _MAX_OUTPUT_PIXELS:
            raise ValueError(
                f"Computed output grid is too large ({out_shape[1]} x "
                f"{out_shape[0]} = {out_shape[0] * out_shape[1]:,} pixels, "
                f"limit is {_MAX_OUTPUT_PIXELS:,}). Increase the resolution "
                f"parameter or reduce the output extent."
            )

    if any_dask:
        result_data = _merge_dask(
            raster_infos, tgt_wkt, out_bounds, out_shape,
            resampling, nd, strategy, chunk_size,
            transform_precision,
            out_dtype=out_dtype,
        )
    else:
        result_data = _merge_inmemory(
            raster_infos, tgt_wkt, out_bounds, out_shape,
            resampling, nd, strategy,
            transform_precision,
            out_dtype=out_dtype,
        )

    y_coords, x_coords = _make_output_coords(out_bounds, out_shape)
    ydim, xdim = _find_spatial_dims(rasters[0])

    out_coords = {ydim: y_coords, xdim: x_coords}
    # Carry forward non-spatial coords from the first raster (e.g. scalar
    # 'time' coord). Skip coords aligned to the rebuilt spatial dims
    # because their values do not apply to the new grid.
    for cname, cval in rasters[0].coords.items():
        if cname in (ydim, xdim):
            continue
        if ydim in cval.dims or xdim in cval.dims:
            continue
        out_coords[cname] = cval

    # Carry the first raster's attrs forward (matches the default
    # strategy='first'). Drop the duplicate `crs_wkt` and re-emit
    # `transform` and `res` for the new output grid (rasterio 6-tuple).
    out_left, _out_bottom, _out_right, out_top = grid['bounds']
    out_res_x = grid['res_x']
    out_res_y = grid['res_y']
    out_attrs = {**rasters[0].attrs}
    out_attrs.pop('crs_wkt', None)
    out_attrs['crs'] = tgt_wkt
    out_attrs['nodata'] = nd
    out_attrs['res'] = (out_res_x, out_res_y)
    out_attrs['transform'] = (
        out_res_x, 0.0, out_left,
        0.0, -out_res_y, out_top,
    )
    # If the first raster used `_FillValue`, propagate the resolved
    # nodata to both keys for consistent round-trip. Leave absent
    # otherwise.
    if '_FillValue' in rasters[0].attrs:
        out_attrs['_FillValue'] = nd
    # Same treatment for `nodatavals` (rasterio convention).
    if 'nodatavals' in rasters[0].attrs:
        old_nv = rasters[0].attrs['nodatavals']
        try:
            n_entries = max(1, len(old_nv))
        except TypeError:
            # A scalar `nodatavals` has no len(); treat it as one entry.
            n_entries = 1
        out_attrs['nodatavals'] = tuple(nd for _ in range(n_entries))

    # Move the mosaic back to the GPU when any input was cupy-backed.
    # An eager merge returns a cupy array; a lazy merge keeps the dask
    # graph and converts each chunk on compute.
    if _any_cupy:
        import cupy as cp
        if any_dask:
            result_data = result_data.map_blocks(
                cp.asarray,
                dtype=result_data.dtype,
                meta=cp.array((), dtype=result_data.dtype),
            )
        else:
            result_data = cp.asarray(result_data)

    result = xr.DataArray(
        result_data,
        dims=[ydim, xdim],
        coords=out_coords,
        name=name or rasters[0].name or 'merged',
        attrs=out_attrs,
    )
    return result
def _place_same_crs(src_data, src_bounds, src_shape, y_desc,
                    out_bounds, out_shape, nodata, x_desc=False):
    """Copy a tile that already shares the output CRS into the output grid.

    The tile is positioned purely by coordinate arithmetic; nothing is
    resampled.  The output grid is north-up with ascending x, i.e. row 0
    lies at the top of ``out_bounds`` and column 0 at its left edge.

    Parameters
    ----------
    src_data : array-like
        2-D source pixels.  Only the overlapping window is sliced and
        then converted with ``np.asarray``.
    src_bounds, out_bounds : (left, bottom, right, top)
        Extents of the source tile and of the output grid.
    src_shape, out_shape : (rows, cols)
        Pixel dimensions of the source tile and of the output grid.
    y_desc : bool
        True when source row 0 is the top of ``src_bounds``.  When False
        the source window is flipped vertically before it is written so
        the result is north-up (#2186).
    nodata : scalar
        Fill value for output pixels the tile does not cover.
    x_desc : bool, optional
        True when source column 0 sits at the maximum x; the window is
        then mirrored horizontally before it is written (#2183).

    Returns
    -------
    numpy.ndarray or None
        A float64 array of shape ``out_shape``.  ``None`` when the
        source and output resolutions differ by more than 1% on either
        axis, which tells the caller to fall back to reprojection.
    """
    n_out_rows, n_out_cols = out_shape
    n_src_rows, n_src_cols = src_shape
    out_left, out_bottom, out_right, out_top = out_bounds
    src_left, src_bottom, src_right, src_top = src_bounds

    # Pixel sizes of both grids.
    out_dx = (out_right - out_left) / n_out_cols
    out_dy = (out_top - out_bottom) / n_out_rows
    src_dx = (src_right - src_left) / n_src_cols
    src_dy = (src_top - src_bottom) / n_src_rows

    # Tile footprint in output pixel indices (rows counted downward from
    # the top of the output).  These may fall outside the output grid.
    col_first = int(round((src_left - out_left) / out_dx))
    col_last = int(round((src_right - out_left) / out_dx))
    row_first = int(round((out_top - src_top) / out_dy))
    row_last = int(round((out_top - src_bottom) / out_dy))

    # The same footprint restricted to the output grid.
    dst_c0 = max(0, col_first)
    dst_c1 = min(n_out_cols, col_last)
    dst_r0 = max(0, row_first)
    dst_r1 = min(n_out_rows, row_last)

    no_overlap = dst_c0 >= dst_c1 or dst_r0 >= dst_r1
    if no_overlap:
        return np.full(out_shape, nodata, dtype=np.float64)

    # Direct copy is only valid when both grids have (nearly) the same
    # pixel size; otherwise signal the caller to reproject instead.
    x_mismatch = abs(src_dx / out_dx - 1.0) > 0.01
    y_mismatch = abs(src_dy / out_dy - 1.0) > 0.01
    if x_mismatch or y_mismatch:
        return None

    # Source columns expressed in ascending-x order.
    asc_c0 = dst_c0 - col_first
    asc_c1 = min(asc_c0 + (dst_c1 - dst_c0), n_src_cols)
    n_cols = asc_c1 - asc_c0

    canvas = np.full(out_shape, nodata, dtype=np.float64)
    if n_cols <= 0:
        return canvas

    # For an x-descending source the same columns live at mirrored
    # positions; the window is reversed after it has been read.
    if x_desc:
        c_lo, c_hi = n_src_cols - asc_c1, n_src_cols - asc_c0
    else:
        c_lo, c_hi = asc_c0, asc_c1

    if y_desc:
        # North-up source: output row R reads source row R - row_first.
        r_lo = dst_r0 - row_first
        r_hi = min(r_lo + (dst_r1 - dst_r0), n_src_rows)
        dst_row = dst_r0
    else:
        # South-up source: output row R reads source row
        # row_last - 1 - R.  Take the ascending run of source rows and
        # flip it below.  After the flip the first row is source row
        # r_hi - 1, which belongs at output row row_last - r_hi.
        r_lo = max(0, row_last - dst_r1)
        r_hi = min(n_src_rows, row_last - dst_r0)
        dst_row = row_last - r_hi

    n_rows = r_hi - r_lo
    if n_rows <= 0:
        return canvas

    window = np.asarray(src_data[r_lo:r_hi, c_lo:c_hi], dtype=np.float64)
    if not y_desc:
        window = window[::-1, :]
    if x_desc:
        window = window[:, ::-1]

    canvas[dst_row:dst_row + n_rows, dst_c0:dst_c0 + n_cols] = window
    return canvas


def _cast_merged_dtype(merged, out_dtype):
    """Convert a float64 mosaic to an integer output dtype.

    For an integer ``out_dtype`` the values are rounded, clipped to the
    dtype's representable range and cast -- the convention the
    reproject() workers use for integer sources (#2505/#3093).  Call
    this only after the NaN canonical sentinel has been replaced by the
    resolved output nodata; for integer ``out_dtype`` that sentinel is
    guaranteed non-NaN by ``_detect_nodata``'s dtype hint, so no NaN
    reaches the cast.

    For any non-integer ``out_dtype`` the input is returned unchanged.
    """
    if not np.issubdtype(out_dtype, np.integer):
        return merged
    limits = np.iinfo(out_dtype)
    rounded = np.round(merged)
    return np.clip(rounded, limits.min, limits.max).astype(out_dtype)


def _merge_inmemory(
    raster_infos, tgt_wkt, out_bounds, out_shape,
    resampling, nodata, strategy,
    transform_precision,
    out_dtype=np.float64,
):
    """Eagerly merge every raster into one numpy mosaic.

    A tile whose CRS equals the target CRS is placed directly with
    ``_place_same_crs``; every other tile (or a same-CRS tile whose
    placement returned ``None``) is reprojected with
    ``_reproject_chunk_numpy``.  Each tile is processed with its own
    nodata sentinel, which is then rewritten to NaN so inputs with
    different sentinels combine correctly.  After the strategy merge,
    NaN is replaced by the requested output ``nodata`` and the result is
    cast to ``out_dtype``.
    """
    from ._crs_utils import _crs_from_wkt
    target = _crs_from_wkt(tgt_wkt)

    tiles = []
    for entry in raster_infos:
        sentinel = entry.get('raster_nodata', float('nan'))
        flip_x = entry.get('x_desc', False)

        # Fast path: identical CRS means no reprojection is needed.
        direct = None
        if entry['src_crs'] == target:
            direct = _place_same_crs(
                entry['raster'].values,
                entry['src_bounds'], entry['src_shape'],
                entry['y_desc'],
                out_bounds, out_shape, sentinel,
                x_desc=flip_x,
            )

        tile = direct if direct is not None else _reproject_chunk_numpy(
            entry['raster'].values,
            entry['src_bounds'], entry['src_shape'], entry['y_desc'],
            entry['src_wkt'], tgt_wkt,
            out_bounds, out_shape,
            resampling, sentinel, transform_precision,
            source_x_desc=flip_x,
        )

        # Rewrite this tile's own sentinel to the canonical NaN.
        if not np.isnan(sentinel):
            tile = np.asarray(tile, dtype=np.float64)
            tile = np.where(tile == sentinel, np.nan, tile)
        tiles.append(tile)

    mosaic = _merge_arrays_numpy(tiles, float('nan'), strategy)
    # Translate the canonical NaN back to the caller's output nodata.
    if not np.isnan(nodata):
        mosaic = np.where(np.isnan(mosaic), nodata, mosaic)
    return _cast_merged_dtype(mosaic, out_dtype)


def _merge_block_adapter(
    block, block_info,
    raster_data_list, src_bounds_list, src_shape_list, y_desc_list,
    src_wkt_list, tgt_wkt,
    out_bounds, out_shape,
    resampling, nodata, strategy, precision,
    src_footprints_tgt,
    raster_nodata_list,
    same_crs_list,
    x_desc_list=None,
    out_dtype=np.float64,
):
    """Compute one output chunk of the mosaic for ``map_blocks``.

    The chunk's position is read from ``block_info[0]['array-location']``
    and converted to bounds with ``_chunk_bounds``.  Rasters whose
    footprint does not overlap those bounds are skipped.  Each remaining
    raster is placed or reprojected with its own nodata sentinel,
    rewritten to NaN, merged with *strategy*, and finally converted back
    to the output ``nodata`` and ``out_dtype``.
    """
    (row_start, row_end), (col_start, col_end) = \
        block_info[0]['array-location']
    chunk_shape = (row_end - row_start, col_end - col_start)
    cb = _chunk_bounds(out_bounds, out_shape,
                       row_start, row_end, col_start, col_end)

    tiles = []
    for idx, source in enumerate(raster_data_list):
        # Skip rasters that cannot contribute to this chunk.
        footprint = src_footprints_tgt[idx]
        if footprint is not None and not _bounds_overlap(cb, footprint):
            continue

        sentinel = raster_nodata_list[idx]
        flip_x = False if x_desc_list is None else x_desc_list[idx]

        direct = None
        if same_crs_list[idx]:
            # Same-CRS path: direct pixel placement (no resampling).
            # The (possibly dask) array is handed over untouched because
            # _place_same_crs slices before np.asarray(), so only the
            # source window for this chunk is materialized.  An eager
            # .compute() here would materialize the full source once per
            # output chunk.
            direct = _place_same_crs(
                source,
                src_bounds_list[idx], src_shape_list[idx],
                y_desc_list[idx],
                cb, chunk_shape, sentinel,
                x_desc=flip_x,
            )

        tile = direct if direct is not None else _reproject_chunk_numpy(
            source,
            src_bounds_list[idx], src_shape_list[idx], y_desc_list[idx],
            src_wkt_list[idx], tgt_wkt,
            cb, chunk_shape,
            resampling, sentinel, precision,
            source_x_desc=flip_x,
        )

        if not np.isnan(sentinel):
            tile = np.asarray(tile, dtype=np.float64)
            tile = np.where(tile == sentinel, np.nan, tile)
        tiles.append(tile)

    if not tiles:
        # Empty-chunk fills must match the template dtype, or a single
        # no-overlap chunk would promote the assembled output (#3096).
        return np.full(chunk_shape, nodata, dtype=out_dtype)

    mosaic = _merge_arrays_numpy(tiles, float('nan'), strategy)
    if not np.isnan(nodata):
        mosaic = np.where(np.isnan(mosaic), nodata, mosaic)
    return _cast_merged_dtype(mosaic, out_dtype)


def _merge_dask(
    raster_infos, tgt_wkt, out_bounds, out_shape,
    resampling, nodata, strategy, chunk_size,
    transform_precision,
    out_dtype=np.float64,
):
    """Build the lazy mosaic: one ``map_blocks`` task per output chunk.

    Per-raster metadata is flattened into parallel lists and bound to
    ``_merge_block_adapter`` with ``functools.partial``.  The adapter is
    mapped over an empty template array whose chunking comes from
    ``_compute_chunk_layout`` and whose dtype is ``out_dtype``.
    """
    import functools

    import dask.array as da

    row_chunks, col_chunks = _compute_chunk_layout(out_shape, chunk_size)

    # Flatten the per-raster records into parallel lists for the worker.
    sources, extents, shapes = [], [], []
    y_flags, x_flags, wkts, sentinels = [], [], [], []
    for entry in raster_infos:
        sources.append(entry['raster'].data)
        extents.append(entry['src_bounds'])
        shapes.append(entry['src_shape'])
        y_flags.append(entry['y_desc'])
        x_flags.append(entry.get('x_desc', False))
        wkts.append(entry['src_wkt'])
        sentinels.append(entry.get('raster_nodata', float('nan')))

    # Source footprints expressed in the target CRS, used by the worker
    # to skip rasters that do not touch a given chunk.
    footprints = [
        _source_footprint_in_target(extent, wkt, tgt_wkt)
        for extent, wkt in zip(extents, wkts)
    ]

    # CRS-equality flags let the per-block adapter take the direct
    # placement shortcut, as the eager _merge_inmemory path does.
    from ._crs_utils import _crs_from_wkt
    target = _crs_from_wkt(tgt_wkt)
    crs_matches = [bool(_crs_from_wkt(wkt) == target) for wkt in wkts]

    # Bind via partial to prevent map_blocks from adding the dask arrays
    # in the source list as whole-array dependencies.
    worker_kwargs = dict(
        raster_data_list=sources,
        src_bounds_list=extents,
        src_shape_list=shapes,
        y_desc_list=y_flags,
        src_wkt_list=wkts,
        tgt_wkt=tgt_wkt,
        out_bounds=out_bounds,
        out_shape=out_shape,
        resampling=resampling,
        nodata=nodata,
        strategy=strategy,
        precision=transform_precision,
        src_footprints_tgt=footprints,
        raster_nodata_list=sentinels,
        same_crs_list=crs_matches,
        x_desc_list=x_flags,
        out_dtype=out_dtype,
    )
    worker = functools.partial(_merge_block_adapter, **worker_kwargs)

    # The template dtype must match what the chunks actually return, or
    # the graph advertises one dtype and computes another (#3096).
    template = da.empty(
        out_shape, dtype=out_dtype, chunks=(row_chunks, col_chunks)
    )
    return da.map_blocks(
        worker,
        template,
        dtype=out_dtype,
        meta=np.array((), dtype=out_dtype),
        **_dask_task_name_kwargs('xrspatial.merge'),
    )