"""Mahalanobis distance for multi-band rasters.
Computes the Mahalanobis distance at each pixel from a reference
distribution, accounting for correlations between bands.
D(x) = sqrt((x - mu)^T * Sigma^-1 * (x - mu))
Supports numpy, cupy, dask+numpy, and dask+cupy backends.
"""
from typing import List, Optional
import numpy as np
import xarray as xr
from xrspatial.utils import (
ArrayTypeFunctionMapping,
_validate_raster,
has_cuda_and_cupy,
has_dask_array,
is_cupy_array,
is_dask_cupy,
validate_arrays,
)
try:
import dask
import dask.array as da
except ImportError:
da = None
dask = None
try:
import cupy
except ImportError:
cupy = None
# ---------------------------------------------------------------------------
# Memory guards
# ---------------------------------------------------------------------------
#
# The eager (non-dask) backends hold several float64 buffers of shape
# (n_bands, H*W) at the same time, in both the statistics phase and the
# per-pixel phase. A conservative tally of the simultaneously live copies:
#
#   - the stacked float64 input                          -> 1 buffer
#   - the reshaped/transposed operand, which BLAS may
#     copy into contiguous memory                        -> 1 buffer
#   - the mean-centred differences                       -> 1 buffer
#   - the differences multiplied by the inverse cov      -> 1 buffer
#
# Four 8-byte float buffers per band gives 32 bytes per cell per band. The
# (n_bands, n_bands) inverse covariance is small enough to ignore.
_BYTES_PER_CELL_PER_BAND = 4 * np.dtype(np.float64).itemsize
def _available_memory_bytes():
    """Return a best-effort estimate of available host memory, in bytes.

    Sources are tried in order:

    1. the ``MemAvailable`` entry of ``/proc/meminfo`` (Linux; reported
       in kB and converted to bytes),
    2. ``psutil.virtual_memory().available`` when psutil is importable,
    3. a fixed 2 GiB fallback.
    """
    kib = 1024
    try:
        with open('/proc/meminfo', 'r') as meminfo:
            for entry in meminfo:
                if not entry.startswith('MemAvailable:'):
                    continue
                fields = entry.split()
                return int(fields[1]) * kib
    except (OSError, ValueError, IndexError):
        # Missing file or an unparsable entry: fall through to psutil.
        pass

    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass

    return 2 * kib ** 3
def _available_gpu_memory_bytes():
    """Return the free memory of the active CUDA device, in bytes.

    Any failure (CuPy not installed, no CUDA runtime, no device, a failed
    query) yields ``0``. Callers treat ``0`` as "free memory unknown" and
    skip the GPU guard.
    """
    try:
        import cupy as _cp
        free_bytes, _ = _cp.cuda.runtime.memGetInfo()
        return int(free_bytes)
    except Exception:
        return 0
def _projected_bytes(n_bands, height, width):
    """Estimate the eager working set, in bytes, for an (n_bands, H, W) stack.

    The estimate is ``_BYTES_PER_CELL_PER_BAND`` bytes per cell per band.
    """
    n_cells = int(height) * int(width)
    return n_cells * int(n_bands) * _BYTES_PER_CELL_PER_BAND
def _check_memory(n_bands, height, width):
    """Guard the eager host backend against exhausting RAM.

    Raises
    ------
    MemoryError
        If the projected working set (see ``_projected_bytes``) exceeds
        half of the host memory reported by ``_available_memory_bytes``.
    """
    needed = _projected_bytes(n_bands, height, width)
    free_ram = _available_memory_bytes()
    if needed <= 0.5 * free_ram:
        return
    raise MemoryError(
        f"mahalanobis on {n_bands} bands of shape "
        f"({height}, {width}) needs ~{needed / 1e9:.1f} GB of "
        f"working memory but only ~{free_ram / 1e9:.1f} GB is "
        f"available. Use a dask-backed DataArray for out-of-core "
        f"processing, or pass smaller inputs."
    )
def _check_gpu_memory(n_bands, height, width):
    """Guard the eager CuPy backend against exhausting device memory.

    The check is skipped entirely when ``_available_gpu_memory_bytes``
    reports no usable figure (0 or less). This happens, for example, on
    hosts without CUDA, where the kernel fails at the cupy boundary anyway.

    Raises
    ------
    MemoryError
        If the projected working set (see ``_projected_bytes``) exceeds
        half of the free memory on the active device.
    """
    free_vram = _available_gpu_memory_bytes()
    if free_vram > 0:
        needed = _projected_bytes(n_bands, height, width)
        if needed > 0.5 * free_vram:
            raise MemoryError(
                f"mahalanobis on {n_bands} bands of shape "
                f"({height}, {width}) needs ~{needed / 1e9:.1f} GB of "
                f"GPU working memory but only ~{free_vram / 1e9:.1f} GB is "
                f"free on the active device. Use a dask+cupy DataArray for "
                f"out-of-core processing, or pass smaller inputs."
            )
# ---------------------------------------------------------------------------
# Phase 1: compute statistics (mean vector, inverse covariance)
# ---------------------------------------------------------------------------
def _compute_stats_numpy(bands_data):
    """Compute the reference statistics from a list of numpy 2-D arrays.

    Only pixels that are finite in *every* band contribute.

    Parameters
    ----------
    bands_data : list of numpy.ndarray
        N same-shaped 2-D arrays, one per band.

    Returns
    -------
    mu : numpy.ndarray, shape (N,), float64
        Per-band mean over the valid pixels.
    inv_cov : numpy.ndarray, shape (N, N), float64
        Inverse of the sample covariance (denominator ``n_valid - 1``).

    Raises
    ------
    ValueError
        If fewer than ``N + 1`` valid pixels exist, or if the covariance
        matrix is singular (the original ``LinAlgError`` is chained as the
        cause).
    """
    n_bands = len(bands_data)
    stacked = np.stack([b.astype(np.float64) for b in bands_data], axis=0)
    flat = stacked.reshape(n_bands, -1)  # (N, pixels) view of ``stacked``

    # Only pixels where ALL bands are finite.
    valid = np.all(np.isfinite(flat), axis=0)
    n_valid = int(np.sum(valid))
    if n_valid < n_bands + 1:
        raise ValueError(
            f"Not enough valid pixels ({n_valid}) to compute statistics "
            f"for {n_bands} bands. Need at least {n_bands + 1}."
        )

    flat_valid = flat[:, valid]  # boolean indexing copies: (N, n_valid)
    # ``flat`` is a view of ``stacked``. Dropping both names releases the
    # full float64 stack before the centred copy is allocated, which lowers
    # peak memory by one (N, pixels) buffer.
    del flat, stacked

    mu = np.mean(flat_valid, axis=1)  # (N,)
    centered = flat_valid - mu[:, np.newaxis]
    cov = (centered @ centered.T) / (n_valid - 1)  # (N, N)

    try:
        inv_cov = np.linalg.inv(cov)
    except np.linalg.LinAlgError as err:
        raise ValueError(
            "Covariance matrix is singular. Bands may be linearly "
            "dependent or have zero variance."
        ) from err
    return mu, inv_cov
def _compute_stats_cupy(bands_data):
    """Compute (mu, inv_cov) from a list of cupy 2-D arrays.

    The large reductions (validity mask, mean, covariance) run on the GPU.
    The (N, N) covariance is then copied to host and inverted with NumPy,
    so both returned statistics are numpy float64 arrays living on host.

    Host inversion is deliberate. ``cupy.linalg.inv`` does not raise on a
    singular matrix unless CuPy's linalg error mode is set to ``'raise'``
    (``cupyx.errstate``). Under the default it would silently return
    non-finite values and yield an all-NaN output. Inverting on host
    raises the same ``ValueError`` as the numpy and dask backends.

    Raises
    ------
    ValueError
        If fewer than ``N + 1`` valid pixels exist, or if the covariance
        matrix is singular.
    """
    n_bands = len(bands_data)
    stacked = cupy.stack(
        [b.astype(cupy.float64) for b in bands_data], axis=0
    )
    flat = stacked.reshape(n_bands, -1)  # (N, pixels) view of ``stacked``

    # Only pixels where ALL bands are finite.
    valid = cupy.all(cupy.isfinite(flat), axis=0)
    n_valid = int(cupy.sum(valid).get())
    if n_valid < n_bands + 1:
        raise ValueError(
            f"Not enough valid pixels ({n_valid}) to compute statistics "
            f"for {n_bands} bands. Need at least {n_bands + 1}."
        )

    flat_valid = flat[:, valid]  # boolean indexing copies: (N, n_valid)
    # Release the full float64 stack before allocating the centred copy.
    del flat, stacked

    mu = cupy.mean(flat_valid, axis=1)
    centered = flat_valid - mu[:, cupy.newaxis]
    cov = (centered @ centered.T) / (n_valid - 1)

    mu_host = cupy.asnumpy(mu)
    cov_host = cupy.asnumpy(cov)
    try:
        inv_cov = np.linalg.inv(cov_host)
    except np.linalg.LinAlgError as err:
        raise ValueError(
            "Covariance matrix is singular. Bands may be linearly "
            "dependent or have zero variance."
        ) from err
    return mu_host, inv_cov
def _compute_stats_dask(bands_data):
    """Compute (mu, inv_cov) from a list of dask 2-D arrays.

    Works for numpy- and cupy-backed dask arrays and makes two passes over
    the data:

    1. one ``dask.compute`` call for the valid-pixel count and the
       per-band sums (which give the mean),
    2. a second computation for the covariance of the mean-centred pixels.

    Pixels that are non-finite in any band are replaced by 0 rather than
    dropped. They add nothing to either sum, and chunk sizes stay known
    (boolean indexing in dask would give unknown chunk sizes).

    Returns
    -------
    mu, inv_cov : numpy.ndarray
        float64 arrays on host, shapes (N,) and (N, N).

    Raises
    ------
    ValueError
        If fewer than ``N + 1`` valid pixels exist, or if the covariance
        matrix is singular.
    """
    n_bands = len(bands_data)
    stacked = da.stack(
        [b.astype(np.float64) for b in bands_data], axis=0
    ).rechunk({0: n_bands})
    flat = stacked.reshape(n_bands, -1)

    valid = da.all(da.isfinite(flat), axis=0)
    n_valid_lazy = da.sum(valid)
    flat_masked = da.where(valid[np.newaxis, :], flat, 0.0)
    mu_sum = da.sum(flat_masked, axis=1)

    # Pass 1: count and per-band sums, materialised together.
    n_valid, mu_sum_val = dask.compute(n_valid_lazy, mu_sum)
    n_valid = int(n_valid)
    if n_valid < n_bands + 1:
        raise ValueError(
            f"Not enough valid pixels ({n_valid}) to compute statistics "
            f"for {n_bands} bands. Need at least {n_bands + 1}."
        )
    # Small (N,) array; it stays on the same backend as the blocks (numpy
    # or cupy), so it can be subtracted from them lazily below.
    mu = mu_sum_val / n_valid

    # Pass 2: covariance of the centred valid pixels.
    centered = da.where(
        valid[np.newaxis, :],
        flat - mu[:, np.newaxis],
        0.0,
    )
    cov = ((centered @ centered.T) / (n_valid - 1)).compute()

    # Bring the small statistics to host if the blocks were cupy-backed.
    if is_cupy_array(cov):
        mu = cupy.asnumpy(mu)
        cov = cupy.asnumpy(cov)

    try:
        inv_cov = np.linalg.inv(cov)
    except np.linalg.LinAlgError as err:
        raise ValueError(
            "Covariance matrix is singular. Bands may be linearly "
            "dependent or have zero variance."
        ) from err
    return mu.astype(np.float64), inv_cov.astype(np.float64)
# ---------------------------------------------------------------------------
# Phase 2: per-pixel distance
# ---------------------------------------------------------------------------
def _mahalanobis_pixel_numpy(stacked, mu, inv_cov):
    """Per-pixel Mahalanobis distance for an (N, H, W) numpy array.

    Parameters
    ----------
    stacked : numpy array, shape (N, H, W)
    mu : numpy array, shape (N,)
    inv_cov : numpy array, shape (N, N)

    Returns
    -------
    numpy array, shape (H, W), float64
        ``sqrt((x - mu)^T inv_cov (x - mu))`` per pixel, or NaN where any
        band is non-finite.
    """
    n_bands, h, w = stacked.shape
    flat = stacked.reshape(n_bands, -1).T  # (pixels, N)

    # Any non-finite band -> output NaN.
    invalid = ~np.all(np.isfinite(flat), axis=1)  # (pixels,)

    diff = flat - mu[np.newaxis, :]  # fresh (pixels, N) array
    # These rows are overwritten with NaN at the end. Zeroing them first
    # keeps NaN/inf inputs out of the arithmetic below, so combinations such
    # as inf * 0 or inf - inf cannot raise floating-point RuntimeWarnings.
    diff[invalid] = 0.0

    transformed = diff @ inv_cov  # (pixels, N)
    # Row-wise dot product. einsum avoids materialising the extra
    # (pixels, N) temporary that ``np.sum(transformed * diff, axis=1)``
    # would allocate.
    dist_sq = np.einsum('ij,ij->i', transformed, diff)
    np.maximum(dist_sq, 0.0, out=dist_sq)  # guard tiny negative round-off
    result = np.sqrt(dist_sq, out=dist_sq)
    result[invalid] = np.nan
    return result.reshape(h, w)
def _mahalanobis_pixel_cupy(stacked, mu, inv_cov):
    """Per-pixel Mahalanobis distance for an (N, H, W) cupy array.

    *mu* (N,) and *inv_cov* (N, N) may be host arrays; they are copied to
    the device with ``cupy.asarray``. Pixels that are non-finite in any
    band come out as NaN. Returns a cupy array of shape (H, W).
    """
    mean_dev = cupy.asarray(mu)
    icov_dev = cupy.asarray(inv_cov)

    n_bands, h, w = stacked.shape
    pixels = stacked.reshape(n_bands, -1).T  # (pixels, N)
    bad = ~cupy.all(cupy.isfinite(pixels), axis=1)

    centred = pixels - mean_dev[cupy.newaxis, :]
    projected = centred @ icov_dev
    # Clamp tiny negative round-off before the square root.
    squared = cupy.maximum(cupy.sum(projected * centred, axis=1), 0.0)
    dist = cupy.sqrt(squared)
    dist[bad] = cupy.nan
    return dist.reshape(h, w)
# ---------------------------------------------------------------------------
# Per-backend entry points
# ---------------------------------------------------------------------------
def _mahalanobis_numpy(bands_data, mu, inv_cov):
    """Eager numpy backend: stack the bands as float64 and score every pixel."""
    as_float = [band.astype(np.float64) for band in bands_data]
    cube = np.stack(as_float, axis=0)  # (N, H, W)
    return _mahalanobis_pixel_numpy(cube, mu, inv_cov)
def _mahalanobis_cupy(bands_data, mu, inv_cov):
    """Eager cupy backend: stack the bands on the device as float64 and score."""
    as_float = [band.astype(cupy.float64) for band in bands_data]
    cube = cupy.stack(as_float, axis=0)  # (N, H, W)
    return _mahalanobis_pixel_cupy(cube, mu, inv_cov)
def _mahalanobis_dask_numpy(bands_data, mu, inv_cov):
    """Lazy dask+numpy backend.

    The bands are stacked into an (N, H, W) dask array whose band axis is a
    single chunk, so every block carries all N bands for its pixels. Each
    block is reduced to a 2-D distance tile by the numpy kernel, and the
    band axis is dropped from the output.
    """
    n_bands = len(bands_data)
    cube = da.stack(
        [band.astype(np.float64) for band in bands_data], axis=0
    ).rechunk({0: n_bands})

    def _block_fn(block):
        return _mahalanobis_pixel_numpy(block, mu, inv_cov)

    out_meta = np.array((), dtype=np.float64)
    return da.map_blocks(
        _block_fn,
        cube,
        drop_axis=0,
        dtype=np.float64,
        meta=out_meta,
    )
def _mahalanobis_dask_cupy(bands_data, mu, inv_cov):
    """Lazy dask+cupy backend.

    Same layout as the dask+numpy backend: an (N, H, W) stack with a single
    chunk along the band axis, reduced block-by-block to 2-D distance
    tiles. Here the cupy kernel does the per-block work, and the output
    meta is a cupy array.
    """
    n_bands = len(bands_data)
    cube = da.stack(
        [band.astype(np.float64) for band in bands_data], axis=0
    ).rechunk({0: n_bands})

    def _block_fn(block):
        return _mahalanobis_pixel_cupy(block, mu, inv_cov)

    out_meta = cupy.array((), dtype=np.float64)
    return da.map_blocks(
        _block_fn,
        cube,
        drop_axis=0,
        dtype=np.float64,
        meta=out_meta,
    )
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def mahalanobis(
    bands: List[xr.DataArray],
    mean: Optional[np.ndarray] = None,
    inv_cov: Optional[np.ndarray] = None,
    name: str = 'mahalanobis',
) -> xr.DataArray:
    """Compute per-pixel Mahalanobis distance from a multi-band reference.

    Parameters
    ----------
    bands : list of xr.DataArray
        N 2-D rasters (same shape, dtype family, and chunk structure).
        Must contain at least 2 bands.
    mean : array-like, shape (N,), optional
        Mean vector of the reference distribution. Must be provided
        together with *inv_cov*, or both omitted (auto-computed). CuPy
        arrays are accepted and copied to host.
    inv_cov : array-like, shape (N, N), optional
        Inverse covariance matrix. Must be provided together with
        *mean*, or both omitted (auto-computed). CuPy arrays are accepted
        and copied to host.
    name : str
        Name for the output DataArray.

    Returns
    -------
    xr.DataArray
        2-D float64 raster with same coords/dims/attrs as ``bands[0]``.
        Pixels that are non-finite in any band are NaN.

    Raises
    ------
    ValueError
        Raised in any of these cases:

        - fewer than 2 bands;
        - only one of *mean* / *inv_cov* given;
        - supplied statistics have the wrong shape or non-finite values;
        - auto-computed statistics have too few valid pixels, or the
          covariance matrix is singular.
    MemoryError
        Eager (non-dask) inputs only, when the projected working set
        exceeds half of the available host or GPU memory.
    """
    # --- input validation ---
    if len(bands) < 2:
        raise ValueError("At least 2 bands are required.")

    # Per-band dtype/ndim check. ``validate_arrays`` only enforces matching
    # shape and array-type, so without this loop a boolean or other
    # non-numeric DataArray would silently coerce to float64.
    for i, band in enumerate(bands):
        _validate_raster(
            band,
            func_name='mahalanobis',
            name=f'bands[{i}]',
            ndim=2,
        )
    validate_arrays(*bands)

    if (mean is None) != (inv_cov is None):
        raise ValueError(
            "Provide both `mean` and `inv_cov`, or neither."
        )

    n_bands = len(bands)
    ref = bands[0]
    bands_data = [b.data for b in bands]

    # --- memory guard (eager backends only) ---
    # Dask paths process bounded chunks, so the per-task working set is
    # capped by the user's chunk size rather than the full raster shape.
    height, width = ref.shape
    is_dask = has_dask_array() and isinstance(ref.data, da.Array)
    is_cupy = has_cuda_and_cupy() and is_cupy_array(ref.data)
    if not is_dask:
        if is_cupy:
            _check_gpu_memory(n_bands, height, width)
        else:
            _check_memory(n_bands, height, width)

    # --- compute or validate statistics ---
    # Statistics end up as host float64 arrays in every case; the cupy
    # kernels copy them to the device themselves.
    if mean is not None:
        # ``np.asarray`` refuses implicit device-to-host conversion of cupy
        # arrays, so copy user-supplied GPU statistics explicitly.
        if is_cupy_array(mean):
            mean = mean.get()
        if is_cupy_array(inv_cov):
            inv_cov = inv_cov.get()
        mu = np.asarray(mean, dtype=np.float64)
        icov = np.asarray(inv_cov, dtype=np.float64)
        if mu.shape != (n_bands,):
            raise ValueError(
                f"`mean` shape {mu.shape} does not match "
                f"number of bands ({n_bands},)."
            )
        if icov.shape != (n_bands, n_bands):
            raise ValueError(
                f"`inv_cov` shape {icov.shape} does not match "
                f"expected ({n_bands}, {n_bands})."
            )
        # Non-finite statistics would silently turn every pixel into NaN.
        if not (np.all(np.isfinite(mu)) and np.all(np.isfinite(icov))):
            raise ValueError(
                "`mean` and `inv_cov` must contain only finite values."
            )
    else:
        # Auto-compute stats; dispatch by backend. Each backend casts to
        # float64 internally.
        if is_dask:
            mu, icov = _compute_stats_dask(bands_data)
        elif is_cupy:
            mu, icov = _compute_stats_cupy(bands_data)
        else:
            mu, icov = _compute_stats_numpy(bands_data)

    # --- per-pixel distance: dispatch by backend ---
    mapper = ArrayTypeFunctionMapping(
        numpy_func=_mahalanobis_numpy,
        cupy_func=_mahalanobis_cupy,
        dask_func=_mahalanobis_dask_numpy,
        dask_cupy_func=_mahalanobis_dask_cupy,
    )
    out = mapper(ref)(bands_data, mu, icov)
    return xr.DataArray(
        out,
        name=name,
        dims=ref.dims,
        coords=ref.coords,
        attrs=ref.attrs,
    )