# Source code for xrspatial.pathfinding

import heapq
import warnings
from collections import OrderedDict
from math import sqrt
from typing import Optional, Union

import numpy as np
import xarray as xr

try:
    import dask.array as da
    import dask
except ImportError:
    da = None
    dask = None

from xrspatial.cost_distance import _heap_push, _heap_pop
from xrspatial.utils import (
    _validate_raster,
    get_dataarray_resolution, ngjit,
    has_cuda_and_cupy, is_cupy_array, is_dask_cupy, has_dask_array,
)

# Sentinel meaning "no index": stored in the A* parent arrays for cells
# that have no parent yet, and returned by _find_nearest_pixel when no
# crossable pixel exists.
NONE = -1

# Maximum waypoint count for multi_stop_search.  optimize_order builds an
# N x N distance matrix by running one A* search per ordered pair of
# waypoints (N(N-1) calls), then orders the interior with Held-Karp
# (N <= 12) or nearest-neighbor + 2-opt (O(N^3) per 2-opt pass), so
# unbounded N is a CPU DoS.
_MAX_WAYPOINTS = 1000


def _get_pixel_id(point, raster, xdim=None, ydim=None):
    # get location in `raster` pixel space for `point` in y-x coordinate space
    # point: (y, x) - coordinates of the point
    # xdim: name of the x coordinate dimension in input `raster`.
    # ydim: name of the x coordinate dimension in input `raster`

    if ydim is None:
        ydim = raster.dims[-2]
    if xdim is None:
        xdim = raster.dims[-1]
    y_coords = raster.coords[ydim].data
    x_coords = raster.coords[xdim].data

    cellsize_x, cellsize_y = get_dataarray_resolution(raster, xdim, ydim)
    py = int(abs(point[0] - y_coords[0]) / cellsize_y)
    px = int(abs(point[1] - x_coords[0]) / cellsize_x)

    # return index of row and column where the `point` located.
    return py, px


@ngjit
def _is_not_crossable(cell_value, barriers):
    """Return True if a cell with *cell_value* cannot be entered.

    NaN cells and cells equal to any entry of *barriers* are blocked.
    """
    blocked = np.isnan(cell_value)
    if not blocked:
        for idx in range(len(barriers)):
            if barriers[idx] == cell_value:
                blocked = True
                break
    return blocked


def _is_not_crossable_py(cell_value, barriers):
    """Pure Python version of _is_not_crossable for the dask A* loop."""
    if np.isnan(cell_value):
        return True
    for b in barriers:
        if cell_value == b:
            return True
    return False


@ngjit
def _is_inside(py, px, h, w):
    # True when (py, px) is a valid index into an h x w grid.
    return (py >= 0) and (py < h) and (px >= 0) and (px < w)


@ngjit
def _find_nearest_pixel(py, px, data, barriers):
    """Return the crossable pixel of *data* closest to ``(py, px)``.

    If ``(py, px)`` itself is crossable it is returned unchanged.
    Otherwise every pixel is scanned and the crossable one with the
    smallest Euclidean (pixel-space) distance wins; ties go to the first
    in row-major order.  Returns ``(NONE, NONE)`` if no pixel is
    crossable.
    """
    if not _is_not_crossable(data[py, px], barriers):
        return py, px

    height, width = data.shape
    nearest_y = NONE
    nearest_x = NONE
    # Exact integer squared distance; -1 means "no candidate yet".  The
    # previous version seeded the minimum with the largest possible
    # distance and compared with a strict '<', so a crossable pixel at
    # exactly that distance (e.g. the opposite corner) was never found.
    best_d2 = -1
    for y in range(height):
        for x in range(width):
            if _is_not_crossable(data[y, x], barriers):
                continue
            d2 = (x - px) ** 2 + (y - py) ** 2
            if best_d2 < 0 or d2 < best_d2:
                best_d2 = d2
                nearest_y = y
                nearest_x = x

    return nearest_y, nearest_x


def _neighborhood_structure(cellsize_x, cellsize_y, connectivity=8):
    """Return (dy, dx, dd) with cellsize-scaled geometric distances."""
    if connectivity == 8:
        dy = np.array([-1, -1, -1, 0, 0, 1, 1, 1], dtype=np.int64)
        dx = np.array([-1, 0, 1, -1, 1, -1, 0, 1], dtype=np.int64)
        dd = np.array([
            sqrt(cellsize_y ** 2 + cellsize_x ** 2),   # (-1,-1)
            cellsize_y,                                  # (-1, 0)
            sqrt(cellsize_y ** 2 + cellsize_x ** 2),   # (-1,+1)
            cellsize_x,                                  # ( 0,-1)
            cellsize_x,                                  # ( 0,+1)
            sqrt(cellsize_y ** 2 + cellsize_x ** 2),   # (+1,-1)
            cellsize_y,                                  # (+1, 0)
            sqrt(cellsize_y ** 2 + cellsize_x ** 2),   # (+1,+1)
        ], dtype=np.float64)
    else:
        dy = np.array([0, -1, 1, 0], dtype=np.int64)
        dx = np.array([-1, 0, 0, 1], dtype=np.int64)
        dd = np.array([cellsize_x, cellsize_y, cellsize_y, cellsize_x],
                      dtype=np.float64)
    return dy, dx, dd


# ---------------------------------------------------------------------------
# Memory safety helpers
# ---------------------------------------------------------------------------

def _available_memory_bytes():
    """Best-effort estimate of available memory in bytes."""
    # Try /proc/meminfo (Linux)
    try:
        with open('/proc/meminfo', 'r') as f:
            for line in f:
                if line.startswith('MemAvailable:'):
                    return int(line.split()[1]) * 1024  # kB → bytes
    except (OSError, ValueError, IndexError):
        pass
    # Try psutil
    try:
        import psutil
        return psutil.virtual_memory().available
    except (ImportError, AttributeError):
        pass
    # Fallback: 2 GB
    return 2 * 1024 ** 3


def _check_memory(height, width):
    """Guard against A* allocations that would not fit in memory.

    The numba A* kernel needs roughly 65 bytes per pixel (parent arrays,
    g-cost, visited, heap arrays, path image, friction).  Raises
    ``MemoryError`` when that estimate exceeds 80% of the memory reported
    by ``_available_memory_bytes``.
    """
    bytes_per_pixel = 65
    budget_fraction = 0.8
    needed = height * width * bytes_per_pixel
    free = _available_memory_bytes()
    if needed <= budget_fraction * free:
        return
    raise MemoryError(
        f"A* on a {height}x{width} grid requires ~{needed / 1e9:.1f} GB "
        f"but only {free / 1e9:.1f} GB is available. "
        f"Use search_radius to limit the search area, "
        f"or use a dask-backed array for out-of-core pathfinding."
    )


@ngjit
def _reconstruct_path(path_img, parent_ys, parent_xs, g_cost,
                      start_py, start_px, goal_py, goal_px):
    """Write the start->goal path into *path_img* by walking parent links.

    Each path pixel (start included) receives its g-cost, i.e. the
    accumulated cost from the start.  If the goal has no parent (it was
    never reached) *path_img* is left untouched.
    """
    if parent_xs[goal_py, goal_px] == NONE or \
            parent_ys[goal_py, goal_px] == NONE:
        return

    path_img[start_py, start_px] = g_cost[start_py, start_px]
    y, x = goal_py, goal_px
    # Walk backwards from the goal until the start (its own parent).
    while y != start_py or x != start_px:
        path_img[y, x] = g_cost[y, x]
        y, x = parent_ys[y, x], parent_xs[y, x]
    return


@ngjit
def _grow_heap(h_keys, h_rows, h_cols, h_size):
    """Return the heap arrays copied into buffers of doubled capacity.

    Only the first *h_size* (live) entries are copied.
    """
    new_cap = 2 * h_keys.shape[0] + 1
    new_keys = np.empty(new_cap, dtype=np.float64)
    new_rows = np.empty(new_cap, dtype=np.int64)
    new_cols = np.empty(new_cap, dtype=np.int64)
    new_keys[:h_size] = h_keys[:h_size]
    new_rows[:h_size] = h_rows[:h_size]
    new_cols[:h_size] = h_cols[:h_size]
    return new_keys, new_rows, new_cols


@ngjit
def _a_star_search(data, path_img, start_py, start_px, goal_py, goal_px,
                   barriers, dy, dx, dd, friction, f_min, use_friction,
                   cellsize_x, cellsize_y):
    """A* from ``(start_py, start_px)`` to ``(goal_py, goal_px)`` on *data*.

    The result is written into *path_img* in place.  Pixels on the found
    path receive their accumulated cost from the start (g-cost); other
    pixels are left untouched.  Nothing is written when the start is not
    crossable, has invalid friction, or the goal is unreachable.

    Parameters
    ----------
    data : 2-D array
        Surface; NaN cells and cells equal to a value in *barriers* are
        impassable.
    path_img : 2-D float array, same shape as *data*
        Output image (the callers in this module pass an all-NaN array).
    barriers : sequence of float
        Impassable surface values.
    dy, dx, dd : 1-D arrays
        Neighbour row/column offsets and step lengths, as returned by
        ``_neighborhood_structure``.
    friction : 2-D array, same shape as *data*
        Read only when *use_friction*.  A neighbour whose friction is
        non-finite or <= 0 is impassable.  The edge cost is the step
        length times the mean friction of the two cells.
    f_min : float
        Multiplies the straight-line heuristic when *use_friction*
        (presumably the minimum friction, keeping the heuristic
        admissible -- confirm at the call site).
    cellsize_x, cellsize_y : float
        Cell sizes used for the heuristic distance.
    """
    height, width = data.shape
    n_neighbors = len(dy)

    # parent of the (i, j) pixel is the pixel at
    # (parent_ys[i, j], parent_xs[i, j])
    parent_ys = np.ones((height, width), dtype=np.int64) * NONE
    parent_xs = np.ones((height, width), dtype=np.int64) * NONE

    # parent of start is itself
    parent_ys[start_py, start_px] = start_py
    parent_xs[start_py, start_px] = start_px

    # g-cost: distance from start to the current node
    g_cost = np.full((height, width), np.inf, dtype=np.float64)

    visited = np.zeros((height, width), dtype=np.int8)

    # Heap arrays.  A cell is pushed again each time its g-cost improves,
    # and stale entries stay in the heap until popped (lazy deletion), so
    # height * width is only an initial capacity: _grow_heap enlarges the
    # buffers on demand instead of writing past their end.
    max_heap = height * width
    h_keys = np.empty(max_heap, dtype=np.float64)
    h_rows = np.empty(max_heap, dtype=np.int64)
    h_cols = np.empty(max_heap, dtype=np.int64)
    h_size = 0

    if not _is_not_crossable(data[start_py, start_px], barriers):
        # Check friction at start when using friction
        if use_friction:
            f_start_val = friction[start_py, start_px]
            if not (np.isfinite(f_start_val) and f_start_val > 0.0):
                return

        g_cost[start_py, start_px] = 0.0

        # Compute heuristic for start
        dy_goal = abs(start_py - goal_py) * cellsize_y
        dx_goal = abs(start_px - goal_px) * cellsize_x
        h = np.sqrt(dy_goal ** 2 + dx_goal ** 2)
        if use_friction:
            h *= f_min

        # The heap is empty and has capacity >= 1 here, so no growth check.
        h_size = _heap_push(h_keys, h_rows, h_cols, h_size,
                            h, start_py, start_px)

    while h_size > 0:
        _, py, px, h_size = _heap_pop(h_keys, h_rows, h_cols, h_size)

        # skip stale heap entries for already-expanded cells
        if visited[py, px]:
            continue
        visited[py, px] = 1

        # found the goal
        if py == goal_py and px == goal_px:
            _reconstruct_path(path_img, parent_ys, parent_xs,
                              g_cost, start_py, start_px,
                              goal_py, goal_px)
            return

        g_u = g_cost[py, px]

        # visit neighborhood
        for i in range(n_neighbors):
            ny = py + dy[i]
            nx = px + dx[i]

            if ny < 0 or ny >= height or nx < 0 or nx >= width:
                continue
            if visited[ny, nx]:
                continue
            if _is_not_crossable(data[ny, nx], barriers):
                continue

            # Compute edge cost
            if use_friction:
                f_u_val = friction[py, px]
                f_v_val = friction[ny, nx]
                # impassable if friction is NaN or non-positive
                if not (np.isfinite(f_v_val) and f_v_val > 0.0):
                    continue
                edge_cost = dd[i] * (f_u_val + f_v_val) * 0.5
            else:
                edge_cost = dd[i]

            new_g = g_u + edge_cost

            if new_g < g_cost[ny, nx]:
                g_cost[ny, nx] = new_g
                parent_ys[ny, nx] = py
                parent_xs[ny, nx] = px

                # Compute heuristic
                dy_goal = abs(ny - goal_py) * cellsize_y
                dx_goal = abs(nx - goal_px) * cellsize_x
                h = np.sqrt(dy_goal ** 2 + dx_goal ** 2)
                if use_friction:
                    h *= f_min

                f_val = new_g + h
                if h_size >= h_keys.shape[0]:
                    h_keys, h_rows, h_cols = _grow_heap(
                        h_keys, h_rows, h_cols, h_size)
                h_size = _heap_push(h_keys, h_rows, h_cols, h_size,
                                    f_val, ny, nx)

    return


# ---------------------------------------------------------------------------
# Bounded A* (sub-region search)
# ---------------------------------------------------------------------------

def _bounded_a_star_sub(surface_data, friction_data, start_py, start_px,
                        goal_py, goal_px, barriers, dy, dx, dd,
                        f_min, use_friction, cellsize_x, cellsize_y,
                        search_radius, h, w):
    """Run A* inside a window around the start and goal pixels.

    The window is the bounding box of start and goal, padded by
    *search_radius* pixels on every side and clipped to the ``h x w``
    grid.  Without friction, an all-ones friction window is used.

    Returns ``(sub_path_img, min_row, min_col)``: the window-sized path
    image and the window's top-left offset in the full grid.
    """
    top = max(0, min(start_py, goal_py) - search_radius)
    bottom = min(h, max(start_py, goal_py) + search_radius + 1)
    left = max(0, min(start_px, goal_px) - search_radius)
    right = min(w, max(start_px, goal_px) + search_radius + 1)
    window = (slice(top, bottom), slice(left, right))

    sub_surface = np.ascontiguousarray(surface_data[window])
    if not use_friction:
        sub_friction = np.ones((bottom - top, right - left),
                               dtype=np.float64)
    else:
        sub_friction = np.ascontiguousarray(friction_data[window])

    sub_path_img = np.full(sub_surface.shape, np.nan, dtype=np.float64)
    # Start / goal are translated into window-local coordinates.
    _a_star_search(sub_surface, sub_path_img,
                   start_py - top, start_px - left,
                   goal_py - top, goal_px - left,
                   barriers, dy, dx, dd,
                   sub_friction, f_min, use_friction,
                   cellsize_x, cellsize_y)

    return sub_path_img, top, left


def _bounded_a_star(surface_data, friction_data, start_py, start_px,
                    goal_py, goal_px, barriers, dy, dx, dd,
                    f_min, use_friction, cellsize_x, cellsize_y,
                    search_radius, h, w):
    """Run windowed A* and return a full-size ``(h, w)`` path image.

    The window result from ``_bounded_a_star_sub`` is pasted into an
    all-NaN array at the window's offset.
    """
    window_path, row0, col0 = _bounded_a_star_sub(
        surface_data, friction_data, start_py, start_px,
        goal_py, goal_px, barriers, dy, dx, dd,
        f_min, use_friction, cellsize_x, cellsize_y,
        search_radius, h, w)

    full = np.full((h, w), np.nan, dtype=np.float64)
    n_rows, n_cols = window_path.shape
    full[row0:row0 + n_rows, col0:col0 + n_cols] = window_path
    return full


# ---------------------------------------------------------------------------
# Hierarchical pathfinding (HPA*)
# ---------------------------------------------------------------------------

@ngjit
def _coarsen_surface(data, factor, barriers):
    """Downsample *data* by averaging ``factor x factor`` blocks.

    Only crossable fine cells (non-NaN, not a barrier value) contribute
    to a block's mean.  A block with no crossable cell becomes NaN, so a
    coarse cell is passable when ANY of its fine cells is.  Edge blocks
    may be smaller than ``factor x factor``.
    """
    h, w = data.shape
    out_h = (h + factor - 1) // factor
    out_w = (w + factor - 1) // factor
    coarse = np.full((out_h, out_w), np.nan, dtype=np.float64)

    for bi in range(out_h):
        row_lo = bi * factor
        row_hi = min(row_lo + factor, h)
        for bj in range(out_w):
            col_lo = bj * factor
            col_hi = min(col_lo + factor, w)
            acc = 0.0
            n = 0
            for r in range(row_lo, row_hi):
                for c in range(col_lo, col_hi):
                    v = data[r, c]
                    if _is_not_crossable(v, barriers):
                        continue
                    acc += v
                    n += 1
            if n > 0:
                coarse[bi, bj] = acc / n

    return coarse


@ngjit
def _coarsen_friction(friction, factor):
    """Downsample *friction* by averaging ``factor x factor`` blocks.

    Only finite, strictly positive fine values contribute.  A block with
    no such value becomes NaN, which the A* kernel treats as impassable.
    """
    h, w = friction.shape
    out_h = (h + factor - 1) // factor
    out_w = (w + factor - 1) // factor
    coarse = np.full((out_h, out_w), np.nan, dtype=np.float64)

    for bi in range(out_h):
        row_lo = bi * factor
        row_hi = min(row_lo + factor, h)
        for bj in range(out_w):
            col_lo = bj * factor
            col_hi = min(col_lo + factor, w)
            acc = 0.0
            n = 0
            for r in range(row_lo, row_hi):
                for c in range(col_lo, col_hi):
                    v = friction[r, c]
                    if not np.isfinite(v) or v <= 0:
                        continue
                    acc += v
                    n += 1
            if n > 0:
                coarse[bi, bj] = acc / n

    return coarse


def _hpa_waypoint(surface_data, friction_data, barriers, use_friction,
                  cr, cc, factor, h, w):
    """Pick the fine-grid pixel that represents coarse cell ``(cr, cc)``.

    The block centre is preferred.  A coarse cell is passable when ANY of
    its fine cells is, so the centre itself may be NaN, a barrier value,
    or (with friction) have non-finite / non-positive friction.  In that
    case the walkable pixel of the block nearest the centre is returned.
    Falls back to the centre if the block has no walkable pixel.
    """
    r0 = cr * factor
    c0 = cc * factor
    r1 = min(r0 + factor, h)
    c1 = min(c0 + factor, w)
    centre_r = min(r0 + factor // 2, h - 1)
    centre_c = min(c0 + factor // 2, w - 1)

    block = np.asarray(surface_data[r0:r1, c0:c1], dtype=np.float64)
    walkable = ~np.isnan(block)
    for b in barriers:
        walkable &= block != b
    if use_friction:
        fb = np.asarray(friction_data[r0:r1, c0:c1], dtype=np.float64)
        walkable &= np.isfinite(fb) & (fb > 0.0)

    lr = centre_r - r0
    lc = centre_c - c0
    if walkable[lr, lc]:
        return centre_r, centre_c

    rows, cols = np.nonzero(walkable)
    if rows.size == 0:
        return centre_r, centre_c
    k = int(np.argmin((rows - lr) ** 2 + (cols - lc) ** 2))
    return r0 + int(rows[k]), c0 + int(cols[k])


def _hpa_star_search(surface_data, friction_data, start_py, start_px,
                     goal_py, goal_px, barriers, dy, dx, dd,
                     f_min, use_friction, cellsize_x, cellsize_y, h, w):
    """Hierarchical pathfinding: coarsen -> route on coarse grid -> refine.

    Uses the existing ``_a_star_search`` kernel on a grid coarsened by
    ``factor = max(16, sqrt(max(h, w)))`` to find a global route.  Each
    segment between consecutive route waypoints is then refined with
    bounded A*, widening the window up to 8x before giving up.

    Returns a full ``(h, w)`` path image whose path pixels hold the
    cumulative cost from the start.  If the coarse route fails, the
    result is all-NaN.  If a refinement segment fails, the path stitched
    so far is returned (partial result).
    """
    factor = max(16, int(np.sqrt(max(h, w))))

    # --- Coarsen ---
    coarse_surface = _coarsen_surface(surface_data, factor, barriers)
    if use_friction:
        coarse_friction = _coarsen_friction(friction_data, factor)
    else:
        coarse_friction = np.ones(coarse_surface.shape, dtype=np.float64)

    ch, cw = coarse_surface.shape

    # Coarse start / goal (clamped)
    cs_py = min(start_py // factor, ch - 1)
    cs_px = min(start_px // factor, cw - 1)
    cg_py = min(goal_py // factor, ch - 1)
    cg_px = min(goal_px // factor, cw - 1)

    # Neighbourhood for coarse grid
    coarse_cx = cellsize_x * factor
    coarse_cy = cellsize_y * factor
    c_dy, c_dx, c_dd = _neighborhood_structure(coarse_cx, coarse_cy, 8)

    # --- Route on coarse grid ---
    coarse_path = np.full((ch, cw), np.nan, dtype=np.float64)
    _a_star_search(coarse_surface, coarse_path,
                   cs_py, cs_px, cg_py, cg_px,
                   barriers, c_dy, c_dx, c_dd,
                   coarse_friction, f_min, use_friction,
                   coarse_cx, coarse_cy)

    # Coarse route cells ordered by ascending cost (= order along the path,
    # since costs strictly increase from start to goal).
    path_cells = []
    for r in range(ch):
        for c in range(cw):
            if np.isfinite(coarse_path[r, c]):
                path_cells.append((coarse_path[r, c], r, c))

    if not path_cells:
        return np.full((h, w), np.nan, dtype=np.float64)

    path_cells.sort()

    # Waypoints: exact start, one walkable pixel per interior coarse cell,
    # exact goal.  Building the list this way (instead of overwriting the
    # first and last entries) keeps both endpoints when the coarse route is
    # a single cell, i.e. start and goal lie in the same coarse block.
    interior = [
        _hpa_waypoint(surface_data, friction_data, barriers, use_friction,
                      cr, cc, factor, h, w)
        for _, cr, cc in path_cells[1:-1]
    ]
    waypoints = [(start_py, start_px)] + interior + [(goal_py, goal_px)]

    # --- Refine segment by segment ---
    path_img = np.full((h, w), np.nan, dtype=np.float64)
    cumulative_cost = 0.0
    base_radius = 2 * factor

    for seg_idx in range(len(waypoints) - 1):
        s_py, s_px = waypoints[seg_idx]
        g_py, g_px = waypoints[seg_idx + 1]

        if s_py == g_py and s_px == g_px:
            continue

        # Widen the search window until the segment goal is reached.
        for multiplier in (1, 2, 4, 8):
            sub_path, min_row, min_col = _bounded_a_star_sub(
                surface_data, friction_data,
                s_py, s_px, g_py, g_px,
                barriers, dy, dx, dd,
                f_min, use_friction, cellsize_x, cellsize_y,
                base_radius * multiplier, h, w)
            seg_goal_cost = sub_path[g_py - min_row, g_px - min_col]
            if np.isfinite(seg_goal_cost):
                break
        else:
            return path_img  # partial result: this segment is unreachable

        # Stitch into full output with cost offset
        mask = np.isfinite(sub_path)
        if seg_idx > 0:
            # Don't overwrite junction (already written as previous goal)
            mask[s_py - min_row, s_px - min_col] = False

        sh, sw = sub_path.shape
        target = path_img[min_row:min_row + sh, min_col:min_col + sw]
        target[mask] = sub_path[mask] + cumulative_cost
        cumulative_cost += seg_goal_cost

    return path_img


# ---------------------------------------------------------------------------
# LRU chunk cache for dask A*
# ---------------------------------------------------------------------------

class _ChunkCache:
    """OrderedDict-based LRU cache for dask chunks."""

    def __init__(self, maxsize=128):
        self._cache = OrderedDict()
        self._maxsize = maxsize

    def get(self, key, loader):
        """Return cached chunk or call *loader()*, evicting oldest if full."""
        if key in self._cache:
            self._cache.move_to_end(key)
            return self._cache[key]
        value = loader()
        if len(self._cache) >= self._maxsize:
            self._cache.popitem(last=False)
        self._cache[key] = value
        return value


# ---------------------------------------------------------------------------
# Sparse dask A*
# ---------------------------------------------------------------------------

def _a_star_dask(surface_da, friction_da, start_py, start_px,
                 goal_py, goal_px, barriers, dy, dx, dd,
                 f_min, use_friction, cellsize_x, cellsize_y, is_cupy):
    """Run A* on a dask-backed array, loading chunks on demand.

    Only the chunks touched by the search are computed.  Each is
    converted to a float64 NumPy array (via ``.get()`` first when
    *is_cupy*) and kept in an LRU ``_ChunkCache``.  Search state (g-costs,
    parents, visited set) lives in dicts/sets keyed by ``(row, col)``
    instead of full-size arrays.  Costs, barriers and friction rules match
    ``_a_star_search``; *friction_da* is read only when *use_friction*.

    Returns
    -------
    list of (row, col, cost)
        Path pixels ordered from the goal back to the start, where *cost*
        is the accumulated cost from the start.  ``[]`` if no path exists.
    """
    height, width = surface_da.shape
    n_neighbors = len(dy)

    # Chunk boundaries: cumulative sums of the chunk sizes along each axis.
    row_bounds = np.cumsum(np.array(surface_da.chunks[0]))
    col_bounds = np.cumsum(np.array(surface_da.chunks[1]))

    surface_cache = _ChunkCache()
    friction_cache = _ChunkCache() if use_friction else None

    def _load_chunk(da_arr, cache, iy, ix):
        """Load and cache a single chunk, converting cupy->numpy."""
        def loader():
            block = da_arr.blocks[iy, ix].compute()
            if is_cupy:
                block = block.get()
            return np.asarray(block, dtype=np.float64)
        return cache.get((iy, ix), loader)

    def _get_value(da_arr, cache, r, c):
        """Get a scalar value at global (r, c) via chunk cache."""
        # side='right' yields the chunk whose half-open range
        # [bounds[i-1], bounds[i]) contains r (resp. c).
        iy = int(np.searchsorted(row_bounds, r, side='right'))
        ix = int(np.searchsorted(col_bounds, c, side='right'))
        chunk = _load_chunk(da_arr, cache, iy, ix)
        local_r = r - (int(row_bounds[iy - 1]) if iy > 0 else 0)
        local_c = c - (int(col_bounds[ix - 1]) if ix > 0 else 0)
        return float(chunk[local_r, local_c])

    def _heuristic(r, c):
        """Straight-line distance to the goal (scaled by f_min w/ friction)."""
        dy_goal = abs(r - goal_py) * cellsize_y
        dx_goal = abs(c - goal_px) * cellsize_x
        est = sqrt(dy_goal ** 2 + dx_goal ** 2)
        if use_friction:
            est *= f_min
        return est

    # Check start
    start_val = _get_value(surface_da, surface_cache, start_py, start_px)
    if _is_not_crossable_py(start_val, barriers):
        return []

    if use_friction:
        f_start = _get_value(friction_da, friction_cache, start_py, start_px)
        if not (np.isfinite(f_start) and f_start > 0.0):
            return []

    # A* data structures (sparse — dict/set, not full arrays)
    g_cost = {(start_py, start_px): 0.0}
    parent = {}
    visited = set()

    counter = 0  # tie-breaker for stable heap ordering
    heap = [(_heuristic(start_py, start_px), counter, start_py, start_px)]

    while heap:
        _, _, py, px = heapq.heappop(heap)

        # skip stale heap entries for already-expanded cells
        if (py, px) in visited:
            continue
        visited.add((py, px))

        # Found goal — reconstruct path (goal -> start)
        if py == goal_py and px == goal_px:
            node = (py, px)
            path = [(py, px, g_cost[node])]
            while node in parent:
                node = parent[node]
                path.append((node[0], node[1], g_cost[node]))
            return path

        g_u = g_cost[(py, px)]
        # The expanded cell's friction is the same for every neighbour,
        # so fetch it once instead of once per neighbour.
        if use_friction:
            f_u_val = _get_value(friction_da, friction_cache, py, px)

        for i in range(n_neighbors):
            ny = py + int(dy[i])
            nx = px + int(dx[i])

            if ny < 0 or ny >= height or nx < 0 or nx >= width:
                continue
            if (ny, nx) in visited:
                continue

            n_val = _get_value(surface_da, surface_cache, ny, nx)
            if _is_not_crossable_py(n_val, barriers):
                continue

            if use_friction:
                f_v_val = _get_value(friction_da, friction_cache, ny, nx)
                if not (np.isfinite(f_v_val) and f_v_val > 0.0):
                    continue
                edge_cost = float(dd[i]) * (f_u_val + f_v_val) * 0.5
            else:
                edge_cost = float(dd[i])

            new_g = g_u + edge_cost

            if new_g < g_cost.get((ny, nx), float('inf')):
                g_cost[(ny, nx)] = new_g
                parent[(ny, nx)] = (py, px)
                counter += 1
                heapq.heappush(
                    heap, (new_g + _heuristic(ny, nx), counter, ny, nx))

    return []  # no path


# ---------------------------------------------------------------------------
# Sparse path → lazy dask output
# ---------------------------------------------------------------------------

def _path_to_dask_array(path_pixels, shape, chunks, is_cupy):
    """Convert a sparse path list to a lazy dask array of the original shape.

    Parameters
    ----------
    path_pixels : list of (row, col, cost)
        Path pixels in global coordinates (any order).
    shape : tuple of int
        Full array shape.  Kept for interface compatibility; the layout is
        fully determined by *chunks*.
    chunks : tuple of sequences
        Dask chunk sizes along rows and columns.
    is_cupy : bool
        Produce cupy-backed blocks instead of NumPy blocks.

    Returns
    -------
    dask.array.Array
        float64 array with the path costs set and NaN everywhere else.
        Each block is built by its own delayed task that only receives
        the pixels falling inside that block.
    """
    row_chunks = chunks[0]
    col_chunks = chunks[1]
    row_bounds = np.cumsum(row_chunks)
    col_bounds = np.cumsum(col_chunks)

    # Group path pixels by chunk
    chunk_pixels = {}  # {(iy, ix): [(local_r, local_c, cost), ...]}
    for r, c, cost in path_pixels:
        iy = int(np.searchsorted(row_bounds, r, side='right'))
        ix = int(np.searchsorted(col_bounds, c, side='right'))
        local_r = r - (int(row_bounds[iy - 1]) if iy > 0 else 0)
        local_c = c - (int(col_bounds[ix - 1]) if ix > 0 else 0)
        chunk_pixels.setdefault((iy, ix), []).append(
            (local_r, local_c, cost))

    # Select the block backend once; blocks are always filled in NumPy.
    if is_cupy:
        import cupy
        to_backend = cupy.asarray
        meta = cupy.array((), dtype=np.float64)
    else:
        to_backend = np.asarray  # no-op for an ndarray
        meta = np.array((), dtype=np.float64)

    @dask.delayed
    def _make_block(ch, cw, pixels):
        block = np.full((ch, cw), np.nan, dtype=np.float64)
        for lr, lc, cost in pixels:
            block[lr, lc] = cost
        return to_backend(block)

    blocks = []
    for iy in range(len(row_chunks)):
        row = []
        ch = int(row_chunks[iy])
        for ix in range(len(col_chunks)):
            cw = int(col_chunks[ix])
            pixels = chunk_pixels.get((iy, ix), [])
            row.append(
                da.from_delayed(
                    _make_block(ch, cw, pixels),
                    shape=(ch, cw),
                    dtype=np.float64,
                    meta=meta,
                )
            )
        blocks.append(row)

    return da.block(blocks)






# ---------------------------------------------------------------------------
# Multi-stop routing
# ---------------------------------------------------------------------------

def _held_karp(dist, start, end):
    """Exact TSP with fixed start and end via Held-Karp bitmask DP.

    Parameters
    ----------
    dist : 2-D array-like, shape (N, N)
        Pairwise costs.  ``dist[i][j]`` is the cost from city *i* to *j*.
    start, end : int
        Indices that must be first and last in the tour.

    Returns
    -------
    (order, total_cost) : (list[int], float)
    """
    n = len(dist)
    if n == 2:
        return [start, end], dist[start][end]

    # Cities to visit between start and end
    mid = [i for i in range(n) if i != start and i != end]
    nm = len(mid)
    INF = float('inf')

    # dp[(mask, city_idx_in_mid)] = min cost to reach city from start
    # visiting exactly the cities indicated by mask
    dp = [[INF] * nm for _ in range(1 << nm)]
    parent = [[-1] * nm for _ in range(1 << nm)]

    # Base: start -> each mid city
    for j, c in enumerate(mid):
        dp[1 << j][j] = dist[start][c]

    for mask in range(1, 1 << nm):
        for j in range(nm):
            if not (mask & (1 << j)):
                continue
            if dp[mask][j] == INF:
                continue
            for k in range(nm):
                if mask & (1 << k):
                    continue
                new_mask = mask | (1 << k)
                new_cost = dp[mask][j] + dist[mid[j]][mid[k]]
                if new_cost < dp[new_mask][k]:
                    dp[new_mask][k] = new_cost
                    parent[new_mask][k] = j

    # Close tour to end
    full = (1 << nm) - 1
    best_cost = INF
    best_last = -1
    for j in range(nm):
        cost = dp[full][j] + dist[mid[j]][end]
        if cost < best_cost:
            best_cost = cost
            best_last = j

    # Reconstruct
    order_mid = []
    mask = full
    cur = best_last
    while cur != -1:
        order_mid.append(mid[cur])
        prev = parent[mask][cur]
        mask ^= (1 << cur)
        cur = prev
    order_mid.reverse()

    return [start] + order_mid + [end], best_cost


def _nearest_neighbor_2opt(dist, start, end):
    """Heuristic TSP for large N: nearest-neighbor + 2-opt with fixed endpoints.

    Parameters
    ----------
    dist : 2-D array-like, shape (N, N)
        Pairwise costs.
    start, end : int
        Fixed first and last indices.

    Returns
    -------
    (order, total_cost) : (list[int], float)
    """
    n = len(dist)
    remaining = set(range(n)) - {start, end}

    # Greedy nearest-neighbor construction
    tour = [start]
    cur = start
    while remaining:
        nearest = min(remaining, key=lambda c: dist[cur][c])
        tour.append(nearest)
        remaining.remove(nearest)
        cur = nearest
    tour.append(end)

    # 2-opt local search (only swap interior segment, keep endpoints fixed)
    def _tour_cost(t):
        return sum(dist[t[i]][t[i + 1]] for i in range(len(t) - 1))

    improved = True
    while improved:
        improved = False
        for i in range(1, len(tour) - 2):
            for j in range(i + 1, len(tour) - 1):
                # Reverse segment tour[i:j+1]
                new_tour = tour[:i] + tour[i:j + 1][::-1] + tour[j + 1:]
                if _tour_cost(new_tour) < _tour_cost(tour):
                    tour = new_tour
                    improved = True

    return tour, _tour_cost(tour)


def _segment_goal_cost(seg, surface, goal, x, y, snap):
    """Return the cost of reaching *goal* in path image *seg*, or ``inf``.

    The cost is read at the goal's pixel.  With *snap*, a non-walkable
    goal is replaced by the nearest valid pixel, so the path ends
    elsewhere and the goal pixel stays NaN.  Path values are cumulative
    costs that grow from start to goal, so the largest finite value is
    then the cost at the snapped goal.  This assumes *seg* holds a
    complete path whenever any pixel is finite (NOTE(review): confirm
    for partial hierarchical results).
    """
    seg_data = seg.data
    if hasattr(seg_data, 'get'):
        seg_vals = seg_data.get()
    else:
        seg_vals = np.asarray(seg.values)
    goal_py, goal_px = _get_pixel_id(goal, surface, x, y)
    goal_cost = seg_vals[goal_py, goal_px]
    if np.isfinite(goal_cost):
        return goal_cost
    if snap:
        finite = seg_vals[np.isfinite(seg_vals)]
        if finite.size:
            return finite.max()
    return float('inf')


def _optimize_waypoint_order(surface, waypoints, barriers, x, y,
                             connectivity, snap, friction, search_radius):
    """Build pairwise cost matrix and solve TSP with fixed endpoints.

    ``dist[i][j]`` is the A* cost from waypoint *i* to waypoint *j*
    (``inf`` when unreachable); one search is run per ordered pair.  The
    first and last waypoints stay fixed.  The interior is ordered with
    Held-Karp for ``N <= 12`` and nearest-neighbor + 2-opt otherwise.

    Returns reordered waypoints list.
    """
    n = len(waypoints)
    INF = float('inf')
    dist = [[INF] * n for _ in range(n)]

    for i in range(n):
        for j in range(n):
            if i == j:
                dist[i][j] = 0.0
                continue
            seg = a_star_search(
                surface, waypoints[i], waypoints[j],
                barriers=barriers, x=x, y=y,
                connectivity=connectivity,
                snap_start=snap, snap_goal=snap,
                friction=friction, search_radius=search_radius,
            )
            dist[i][j] = _segment_goal_cost(
                seg, surface, waypoints[j], x, y, snap)

    # Fixed endpoints: first=0, last=n-1
    if n <= 12:
        order, _ = _held_karp(dist, 0, n - 1)
    else:
        order, _ = _nearest_neighbor_2opt(dist, 0, n - 1)

    return [waypoints[i] for i in order]


def multi_stop_search(surface: xr.DataArray,
                      waypoints: list,
                      barriers: list = [],
                      x: Optional[str] = 'x',
                      y: Optional[str] = 'y',
                      connectivity: int = 8,
                      snap: bool = False,
                      friction: xr.DataArray = None,
                      search_radius: Optional[int] = None,
                      optimize_order: bool = False) -> xr.DataArray:
    """Find the least-cost path visiting a sequence of waypoints in order.

    Wraps :func:`a_star_search` to route through *N* waypoints,
    stitching segments into a single cumulative-cost surface.  When
    ``optimize_order=True``, the interior waypoints are reordered to
    minimize total travel cost (TSP), keeping the first and last
    waypoints fixed.

    Parameters
    ----------
    surface : xr.DataArray
        2-D elevation / cost surface.
    waypoints : list of array-like
        Sequence of ``(y, x)`` coordinate pairs to visit.  Must contain
        at least two points and at most ``_MAX_WAYPOINTS``.
    barriers : list, default=[]
        Surface values that are impassable.  Not modified.
    x, y : str, default ``'x'`` / ``'y'``
        Coordinate dimension names.
    connectivity : int, default=8
        4 or 8 connectivity.
    snap : bool, default=False
        Snap each waypoint to the nearest valid pixel.  Not supported
        with dask-backed arrays.
    friction : xr.DataArray, optional
        Friction surface (same shape as *surface*).
    search_radius : int, optional
        Passed to each :func:`a_star_search` call.
    optimize_order : bool, default=False
        Reorder interior waypoints to minimize total cost.  Uses exact
        Held-Karp when N <= 12, nearest-neighbor + 2-opt otherwise.

    Returns
    -------
    xr.DataArray
        Cumulative path cost surface, a NumPy-backed ``float64`` array
        that is NaN off the path.  Attributes:

        * ``waypoint_order`` -- list of waypoint tuples in visiting order.
        * ``segment_costs`` -- list of Python floats, one per segment.
        * ``total_cost`` -- Python float, the sum of ``segment_costs``.

    Raises
    ------
    ValueError
        If the surface is not 2-D, fewer than two (or more than
        ``_MAX_WAYPOINTS``) waypoints are given, a waypoint is not a
        pair, waypoints fall outside the surface bounds, ``friction``
        has a different shape, ``snap`` is combined with a dask-backed
        surface, or a segment is unreachable.
    """
    # --- Input validation ---
    _validate_raster(surface, func_name='multi_stop_search',
                     name='surface', ndim=2)

    if friction is not None:
        _validate_raster(friction, func_name='multi_stop_search',
                         name='friction', ndim=2)

    if len(waypoints) < 2:
        raise ValueError("at least 2 waypoints are required")

    if len(waypoints) > _MAX_WAYPOINTS:
        raise ValueError(
            f"multi_stop_search() supports at most {_MAX_WAYPOINTS} "
            f"waypoints, got {len(waypoints)}.  optimize_order is "
            f"O(N^3) so larger lists can hang the worker."
        )

    # Check every waypoint's arity before any bounds lookup so malformed
    # input is reported as such rather than as an indexing error.
    for idx, wp in enumerate(waypoints):
        if len(wp) != 2:
            raise ValueError(
                f"waypoint {idx} must have exactly 2 elements (y, x)")

    h, w = surface.shape
    for idx, wp in enumerate(waypoints):
        py, px = _get_pixel_id(wp, surface, x, y)
        if not _is_inside(py, px, h, w):
            raise ValueError(
                f"waypoint {idx} ({wp}) is outside the surface bounds")

    if friction is not None and friction.shape != surface.shape:
        raise ValueError("friction must have the same shape as surface")

    surface_data = surface.data
    _is_dask = da is not None and isinstance(surface_data, da.Array)
    if snap and _is_dask:
        raise ValueError(
            "snap is not supported with dask-backed arrays; "
            "ensure waypoints are valid before calling multi_stop_search")

    if optimize_order:
        if len(waypoints) < 3:
            warnings.warn(
                "optimize_order has no effect with fewer than 3 waypoints",
                stacklevel=2,
            )
        else:
            waypoints = _optimize_waypoint_order(
                surface, list(waypoints), barriers, x, y,
                connectivity, snap, friction, search_radius,
            )

    # --- Segment-by-segment routing ---
    path_data = np.full(surface.shape, np.nan, dtype=np.float64)
    cumulative_cost = 0.0
    segment_costs = []

    # Pre-compute pixel coords for all waypoints.  Entries may be updated
    # below when snapping moves a segment's goal pixel.
    waypoint_pixels = [_get_pixel_id(wp, surface, x, y) for wp in waypoints]

    for i in range(len(waypoints) - 1):
        seg = a_star_search(
            surface, waypoints[i], waypoints[i + 1],
            barriers=barriers, x=x, y=y,
            connectivity=connectivity,
            snap_start=snap, snap_goal=snap,
            friction=friction, search_radius=search_radius,
        )
        seg_data = seg.data
        if hasattr(seg_data, 'get'):
            seg_vals = seg_data.get()  # cupy -> numpy
        else:
            seg_vals = np.asarray(seg.values)

        goal_py, goal_px = waypoint_pixels[i + 1]

        # If snap is on, the actual goal pixel may differ from the
        # requested one.  Take the pixel with the maximum *finite* cost
        # (the true goal of this segment); non-finite cells are excluded
        # explicitly so a stray +inf can never be chosen.
        if snap and not np.isfinite(seg_vals[goal_py, goal_px]):
            finite = np.isfinite(seg_vals)
            if finite.any():
                max_idx = np.argmax(np.where(finite, seg_vals, -np.inf))
                snapped = np.unravel_index(max_idx, seg_vals.shape)
                goal_py, goal_px = int(snapped[0]), int(snapped[1])
                waypoint_pixels[i + 1] = (goal_py, goal_px)

        seg_goal_cost = seg_vals[goal_py, goal_px]

        if not np.isfinite(seg_goal_cost):
            raise ValueError(
                f"no path between waypoints {i} and {i + 1}")

        mask = np.isfinite(seg_vals)
        if i > 0:
            # Don't overwrite the junction pixel (set by previous segment)
            sp_y, sp_x = waypoint_pixels[i]
            mask[sp_y, sp_x] = False

        path_data[mask] = seg_vals[mask] + cumulative_cost

        # Convert once so segment_costs and total_cost are both plain
        # Python floats (float + np.float64 would otherwise yield
        # np.float64 for total_cost).
        seg_cost = float(seg_goal_cost)
        segment_costs.append(seg_cost)
        cumulative_cost += seg_cost

    path_agg = xr.DataArray(
        path_data,
        coords=surface.coords,
        dims=surface.dims,
        attrs={
            'waypoint_order': [tuple(wp) for wp in waypoints],
            'segment_costs': segment_costs,
            'total_cost': float(cumulative_cost),
        },
    )

    return path_agg