Source code for earthdaily.earthdatastore.cube_utils._zonal

"""Zonal statistics computation for geospatial data analysis.

This module provides functionality for computing zonal statistics on xarray Datasets
using various methods including numpy and xvec. It supports parallel processing,
memory optimization, and various statistical operations.

Example:
    >>> import xarray as xr
    >>> import geopandas as gpd
    >>> dataset = xr.open_dataset("temperature.nc")
    >>> polygons = gpd.read_file("zones.geojson")
    >>> stats = zonal_stats(dataset, polygons, reducers=["mean", "max"])
"""

from typing import Union, List, Optional, Tuple, Dict
import logging
import time
from pathlib import Path

import numpy as np
import xarray as xr
import pandas as pd
import geopandas as gpd
from scipy.sparse import csr_matrix
from scipy.stats import mode
from tqdm.auto import trange
import psutil

from .preprocessing import rasterize

# Configure logging
logger = logging.getLogger(__name__)


class MemoryManager:
    """Manages memory allocation for large dataset processing."""

    @staticmethod
    def calculate_time_chunks(
        dataset: xr.Dataset, max_memory_mb: Optional[float] = None
    ) -> int:
        """Calculate how many time steps to process per chunk for a memory budget.

        The footprint of one date is estimated as the dataset size in megabytes
        divided by the number of time steps, multiplied by 3 (presumably
        headroom for intermediate copies -- TODO confirm the factor). The
        dates are then split into the smallest number of equally sized groups
        such that one group fits in the budget.

        Args:
            dataset: Input xarray Dataset; its ``nbytes`` and ``time.size``
                are read.
            max_memory_mb: Maximum memory to use in megabytes. When ``None``,
                the memory currently reported as available by ``psutil`` is
                used.

        Returns:
            int: Number of time steps per chunk (always at least 1).
        """
        if max_memory_mb is None:
            max_memory_mb = psutil.virtual_memory().available / 1e6
            logger.info(f"Using maximum available memory: {max_memory_mb:.2f}MB")

        n_dates = int(dataset.time.size)
        if n_dates == 0:
            # Nothing to split; still hand back a valid positive step so that
            # callers can use it in ``range(0, n, step)`` without crashing.
            return 1

        mb_per_date = (dataset.nbytes / 1e6) / n_dates * 3

        # How many dates fit in the budget. The 0.1 MB slack keeps the divisor
        # strictly positive for empty datasets. Clamp to one so that a tiny
        # (or non-positive) budget degrades to date-by-date processing instead
        # of producing a zero step.
        max_chunks = max(1, int(np.ceil(max_memory_mb / (mb_per_date + 0.1))))

        # Number of groups needed (ceil division), then spread the dates
        # evenly over those groups.
        n_groups = -(-n_dates // max_chunks)
        time_chunks = n_dates // n_groups

        logger.info(
            f"Estimated memory per date: {mb_per_date:.2f}MB. Total: {(mb_per_date*n_dates):.2f}MB"
        )
        logger.info(f"Time chunks: {time_chunks} (total time steps: {n_dates})")

        return time_chunks


class SpatialIndexer:
    """Spatial indexing helpers built around a rasterized label grid."""

    @staticmethod
    def compute_sparse_matrix(data: np.ndarray) -> csr_matrix:
        """Build a label-by-pixel sparse matrix from an array of labels.

        Row ``r`` of the result stores, at column ``c``, the flat index ``c``
        of every element of ``data`` whose value is ``r``.

        Args:
            data: Array of non-negative integer labels.

        Returns:
            scipy.sparse.csr_matrix: Matrix of shape
            ``(data.max() + 1, data.size)``.
        """
        flat_positions = np.arange(data.size)
        flat_labels = data.ravel()
        n_labels = data.max() + 1
        return csr_matrix(
            (flat_positions, (flat_labels, flat_positions)),
            shape=(n_labels, data.size),
        )

    @staticmethod
    def get_sparse_indices(data: np.ndarray) -> List[Tuple[np.ndarray, ...]]:
        """Group the element positions of ``data`` by label value.

        Args:
            data: Array of non-negative integer labels.

        Returns:
            One entry per label value from ``0`` to ``data.max()``; each entry
            is the tuple of coordinate arrays (one per axis of ``data``)
            locating the elements carrying that label.
        """
        label_matrix = SpatialIndexer.compute_sparse_matrix(data)
        grouped = []
        for label_row in label_matrix:
            # The stored values of a row are the flat indices of its pixels.
            grouped.append(np.unravel_index(label_row.data, data.shape))
        return grouped

    @staticmethod
    def rasterize_geometries(
        gdf: gpd.GeoDataFrame, dataset: xr.Dataset, all_touched: bool = False
    ) -> Tuple[np.ndarray, List[np.ndarray]]:
        """Rasterize geometries and index the resulting pixels per label.

        Args:
            gdf: Input GeoDataFrame
            dataset: Reference dataset passed to ``rasterize``
            all_touched: Forwarded to ``rasterize``

        Returns:
            Tuple of the rasterized label array and the per-label positions
            computed by ``get_sparse_indices``.
        """
        label_grid = rasterize(gdf, dataset, all_touched=all_touched)
        return label_grid, SpatialIndexer.get_sparse_indices(label_grid)


class StatisticalOperations:
    """Handles statistical computations on spatial data."""

    @staticmethod
    def zonal_stats(
        dataset: xr.Dataset, positions: List[np.ndarray], reducers: List[str]
    ) -> xr.Dataset:
        """Compute zonal statistics for given positions using specified reducers.

        Args:
            dataset: Input dataset with ``y`` and ``x`` dimensions.
            positions: One entry per zone; each entry unpacks into the
                ``(y, x)`` integer index arrays of the zone's pixels.
            reducers: Names of the statistics to compute. ``"mode"`` uses
                ``scipy.stats.mode`` (NaNs omitted); any other name resolves
                to ``np.nan<name>`` when numpy provides it, otherwise to
                ``np.<name>``.

        Returns:
            The result of ``xarray.apply_ufunc``: the ``y``/``x`` dimensions
            are replaced by ``feature`` (one entry per zone) and
            ``zonal_statistics`` (one entry per reducer).

        Notes:
            Uses xarray's apply_ufunc with ``dask="parallelized"`` so that
            dask-backed inputs are processed block by block.
        """

        def _zonal_stats_ufunc(data, positions, reducers):
            """Reduce the trailing ``(y, x)`` axes of ``data`` zone by zone."""
            if len(positions) == 0:
                # No zone to reduce: keep the leading axes and return an empty
                # feature axis instead of failing on the axis reordering.
                return np.empty(data.shape[:-2] + (0, len(reducers)), dtype=float)

            per_zone = []
            for zone_position in positions:
                # Gather the zone's pixels once; the fancy index collapses
                # (y, x) into a single trailing pixel axis shared by all
                # reducers.
                zone_values = data[(...,) + tuple(zone_position)]
                zone_stats = []
                for reducer in reducers:
                    if reducer == "mode":
                        reduced = mode(zone_values, axis=-1, nan_policy="omit").mode
                    else:
                        # Prefer the NaN-aware variant when numpy has one.
                        func = (
                            f"nan{reducer}" if hasattr(np, f"nan{reducer}") else reducer
                        )
                        reduced = getattr(np, func)(zone_values, axis=-1)
                    zone_stats.append(reduced)
                per_zone.append(np.asarray(zone_stats))

            # Stacked layout is (feature, reducer, *leading); move the two new
            # axes to the end so the result is (*leading, feature, reducer)
            # whatever the number of leading axes.
            stacked = np.asarray(per_zone)
            return np.moveaxis(stacked, (0, 1), (-2, -1))

        # Apply the function using xarray's parallel processing capabilities
        return xr.apply_ufunc(
            _zonal_stats_ufunc,
            dataset,
            vectorize=False,
            dask="parallelized",
            input_core_dims=[["y", "x"]],
            output_core_dims=[["feature", "zonal_statistics"]],
            exclude_dims={"x", "y"},
            output_dtypes=[float],
            kwargs={"reducers": reducers, "positions": positions},
            dask_gufunc_kwargs={
                "allow_rechunk": True,
                "output_sizes": {
                    "feature": len(positions),
                    "zonal_statistics": len(reducers),
                },
            },
        )


def zonal_stats(
    dataset: xr.Dataset,
    geometries: Union[gpd.GeoDataFrame, gpd.GeoSeries],
    method: str = "numpy",
    lazy_load: bool = True,
    max_memory_mb: Optional[float] = None,
    reducers: Optional[List[str]] = None,
    all_touched: bool = True,
    preserve_columns: bool = True,
    buffer_meters: Optional[Union[int, float]] = None,
    **kwargs,
) -> xr.Dataset:
    """Calculate zonal statistics for xarray Dataset based on geometric boundaries.

    This function computes statistical summaries of Dataset values within each
    geometry's zone, supporting parallel processing through xarray's
    apply_ufunc and multiple computation methods.

    Parameters
    ----------
    dataset : xarray.Dataset
        Input dataset containing variables for statistics computation.
    geometries : Union[geopandas.GeoDataFrame, geopandas.GeoSeries]
        Geometries defining the zones for statistics calculation. A GeoSeries
        is promoted to a single-column GeoDataFrame.
    method : str, optional
        Method for computation. Options:

        - 'numpy': Uses numpy functions with parallel processing
        - 'xvec': Uses xvec library (must be installed)

        Default is 'numpy'.
    lazy_load : bool, optional
        Only used by the 'numpy' method. If True, the dataset is not loaded
        eagerly and the statistics are computed directly through
        ``xarray.apply_ufunc``. If False and the dataset has a ``time``
        dimension, the data is loaded and processed in time chunks sized from
        ``max_memory_mb``. Default is True.
    max_memory_mb : float, optional
        Maximum memory to use in megabytes. If None, uses maximum available
        memory. Default is None.
    reducers : list[str], optional
        List of statistical operations to perform. Functions should be numpy
        nan-functions (e.g., 'mean' uses np.nanmean). Default is ['mean'].
    all_touched : bool, optional
        If True, includes all pixels touched by geometries in computation.
        Default is True.
    preserve_columns : bool, optional
        If True, preserves all columns from input geometries in output.
        Default is True.
    buffer_meters : Union[int, float, None], optional
        Buffer distance in meters to apply to geometries before computation.
        Default is None.
    **kwargs : dict
        Additional keyword arguments passed to underlying computation
        functions.

    Returns
    -------
    xarray.Dataset
        Dataset containing computed statistics with dimensions:

        - time (if present in input)
        - feature (number of geometries)
        - zonal_statistics (number of reducers)

        Additional coordinates include geometry WKT and preserved columns if
        requested.

    See Also
    --------
    xarray.apply_ufunc : Function used for parallel computation
    rasterio.features : Used for geometry rasterization

    Notes
    -----
    With the 'numpy' method and ``lazy_load=False``, time series are processed
    in chunks determined by the memory budget. The 'xvec' method requires the
    xvec package to be installed separately.

    Examples
    --------
    >>> import xarray as xr
    >>> import geopandas as gpd
    >>> dataset = xr.open_dataset("temperature.nc")
    >>> polygons = gpd.read_file("zones.geojson")
    >>> stats = zonal_stats(
    ...     dataset,
    ...     polygons,
    ...     reducers=["mean", "max"],
    ...     lazy_load=True
    ... )

    Raises
    ------
    ImportError
        If 'xvec' method is selected but xvec package is not installed.
    ValueError
        If invalid method or reducer is specified.
    DeprecationWarning
        If deprecated parameters are used.
    """
    # Input validation and deprecation warnings
    if "label" in kwargs:
        raise DeprecationWarning(
            '"label" parameter is deprecated and removed in earthdaily>=0.5. '
            "All geometry columns are preserved by default (preserve_columns=True)."
        )
    if "smart_load" in kwargs:
        import warnings

        warnings.warn(
            '"smart_load" will be deprecated in earthdaily>=0.6. '
            'Use "lazy_load" instead (lazy_load=True == smart_load=False).',
            DeprecationWarning,
            stacklevel=2,
        )
        # Pop the legacy flag so it is not forwarded to the backends, which do
        # not know about it.
        lazy_load = not kwargs.pop("smart_load")

    # Fail fast on an unknown backend, before any clipping or buffering work.
    if method not in ("numpy", "xvec"):
        raise ValueError(f"Unsupported method: {method}")

    # A fresh list per call avoids sharing one mutable default between calls.
    reducers = ["mean"] if reducers is None else reducers

    # The helpers below rely on GeoDataFrame-only attributes (e.g. columns).
    if isinstance(geometries, gpd.GeoSeries):
        geometries = gpd.GeoDataFrame(geometry=geometries)

    # Apply buffer if specified. This happens before clipping so that the
    # clipped dataset still covers the buffered zones.
    if buffer_meters is not None:
        geometries = _apply_buffer(geometries, buffer_meters)

    # Clip dataset to geometry bounds
    dataset = dataset.rio.clip_box(*geometries.to_crs(dataset.rio.crs).total_bounds)

    if method == "numpy":
        return _compute_numpy_stats(
            dataset,
            geometries,
            lazy_load,
            max_memory_mb,
            reducers,
            all_touched,
            preserve_columns,
            **kwargs,
        )
    return _compute_xvec_stats(
        dataset, geometries, reducers, all_touched, preserve_columns, **kwargs
    )
def _apply_buffer(
    geometries: gpd.GeoDataFrame, buffer_meters: Union[int, float]
) -> gpd.GeoDataFrame:
    """Apply buffer to geometries in meters.

    The geometries are reprojected to the PROJ ``cea`` (cylindrical equal
    area) projection, buffered by ``buffer_meters`` and reprojected back to
    their original CRS. The unbuffered geometry is kept in a
    ``geometry_original`` column.
    """
    original_crs = geometries.crs
    # to_crs returns a new frame, so the caller's object is not modified below.
    geometries = geometries.to_crs({"proj": "cea"})
    # NOTE(review): this copy is taken after reprojection, so
    # "geometry_original" holds the unbuffered shapes in the cea projection,
    # and the final to_crs only converts the active geometry column --
    # confirm this is what consumers expect.
    geometries["geometry_original"] = geometries.geometry
    geometries.geometry = geometries.buffer(buffer_meters)
    return geometries.to_crs(original_crs)


def _compute_numpy_stats(
    dataset: xr.Dataset,
    geometries: gpd.GeoDataFrame,
    lazy_load: bool,
    max_memory_mb: Optional[float],
    reducers: List[str],
    all_touched: bool,
    preserve_columns: bool,
    **kwargs,
) -> xr.Dataset:
    """Compute zonal statistics using numpy method.

    Rasterizes the geometries on the dataset grid, reduces every zone with
    ``StatisticalOperations.zonal_stats`` (in time chunks when the dataset has
    a ``time`` dimension and ``lazy_load`` is False) and formats the result.
    Extra ``kwargs`` are accepted but not used here.
    """
    # Rasterize geometries
    features, yx_positions = SpatialIndexer.rasterize_geometries(
        geometries.copy(), dataset, all_touched
    )
    # The first entry (label 0) is skipped -- presumably the background of the
    # rasterization; zones that received no pixel are dropped as well.
    positions = [pos for pos in map(np.asarray, yx_positions[1:]) if pos.size > 0]

    # Process time series if present
    if "time" in dataset.dims and not lazy_load:
        time_chunks = MemoryManager.calculate_time_chunks(dataset, max_memory_mb)
        stats = _process_time_chunks(
            dataset, positions, reducers, lazy_load, time_chunks
        )
    else:
        stats = StatisticalOperations.zonal_stats(dataset, positions, reducers)

    # Format output
    return _format_numpy_output(stats, features, geometries, reducers, preserve_columns)


def _process_time_chunks(
    dataset: xr.Dataset,
    positions: List[np.ndarray],
    reducers: List[str],
    lazy_load: bool,
    time_chunks: int,
) -> xr.Dataset:
    """Process dataset in time chunks to optimize memory usage.

    Slices the dataset along ``time`` in steps of ``time_chunks``, loads each
    slice when ``lazy_load`` is False, computes its zonal statistics and
    concatenates the per-slice results along ``time``.
    """
    n_dates = dataset.time.size
    chunks = []
    for start in trange(0, n_dates, time_chunks):
        stop = min(start + time_chunks, n_dates)
        ds_chunk = dataset.isel(time=slice(start, stop))

        if not lazy_load:
            load_start = time.time()
            ds_chunk = ds_chunk.load()
            logger.debug(
                f"Loaded {ds_chunk.time.size} dates in "
                f"{(time.time()-load_start):.2f}s"
            )

        compute_start = time.time()
        chunk_stats = StatisticalOperations.zonal_stats(ds_chunk, positions, reducers)
        logger.debug(f"Computed chunk statistics in {(time.time()-compute_start):.2f}s")
        chunks.append(chunk_stats)

    return xr.concat(chunks, dim="time")


def _compute_xvec_stats(
    dataset: xr.Dataset,
    geometries: gpd.GeoDataFrame,
    reducers: List[str],
    all_touched: bool,
    preserve_columns: bool,
    **kwargs,
) -> xr.Dataset:
    """Compute zonal statistics using xvec method.

    Args:
        dataset: Input dataset
        geometries: Input geometries
        reducers: List of statistical operations to perform
        all_touched: Whether to include all touched pixels
        preserve_columns: Whether to preserve geometry columns
        **kwargs: Additional keyword arguments forwarded to
            ``dataset.xvec.zonal_stats``

    Returns:
        xarray.Dataset: Computed statistics

    Raises:
        ImportError: If xvec package is not installed
    """
    try:
        # Imported for its side effect only: the ``.xvec`` accessor used below.
        import xvec  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "The xvec method requires the xvec package. "
            "Please install it with: pip install xvec"
        ) from err

    # Compute statistics using xvec
    stats = dataset.xvec.zonal_stats(
        geometries.to_crs(dataset.rio.crs).geometry,
        y_coords="y",
        x_coords="x",
        stats=reducers,
        method="rasterize",
        all_touched=all_touched,
        **kwargs,
    )

    # Drop geometry and add as coordinate
    stats = stats.drop_vars("geometry")
    stats = stats.assign_coords(
        geometry=("feature", geometries.geometry.to_wkt(rounding_precision=-1).values)
    )

    # Add index coordinate
    stats = stats.assign_coords(index=("feature", geometries.index))
    stats = stats.set_index(feature=["geometry", "index"])

    # Put feature / zonal_statistics last to match the numpy method output;
    # the ellipsis keeps any leading dimension (e.g. time) and also works for
    # datasets that have no time dimension.
    stats = stats.transpose(..., "feature", "zonal_statistics")

    # Preserve additional columns if requested
    if preserve_columns:
        stats = _preserve_geometry_columns(stats, geometries)

    return stats


def _format_numpy_output(
    stats: xr.Dataset,
    features: np.ndarray,
    geometries: gpd.GeoDataFrame,
    reducers: List[str],
    preserve_columns: bool,
) -> xr.Dataset:
    """Format numpy statistics output.

    Labels the ``zonal_statistics`` axis with the reducer names, writes
    EPSG:4326 as CRS and indexes the ``feature`` axis by the geometry index and
    the geometry WKT (in EPSG:4326).
    """
    # Set coordinates and metadata
    stats = stats.assign_coords(zonal_statistics=reducers)
    stats = stats.rio.write_crs("EPSG:4326")

    # Process features and create index. Labels present in the raster are
    # mapped back to geometry rows as ``label - 1`` (assumes rasterize numbers
    # the geometries from 1 in row order -- TODO confirm against rasterize).
    feature_indices = np.unique(features)
    feature_indices = feature_indices[feature_indices > 0]
    row_positions = feature_indices - 1
    index = geometries.index[row_positions]

    # Convert geometries to WKT
    if geometries.crs.to_epsg() != 4326:
        geometries = geometries.to_crs("EPSG:4326")
    geometry_wkt = (
        geometries.geometry.iloc[row_positions].to_wkt(rounding_precision=-1).values
    )

    # Assign coordinates
    coords = {"index": ("feature", index), "geometry": ("feature", geometry_wkt)}
    stats = stats.assign_coords(**coords)
    stats = stats.set_index(feature=list(coords.keys()))

    # Preserve additional columns if requested
    if preserve_columns:
        stats = _preserve_geometry_columns(stats, geometries)

    return stats


def _preserve_geometry_columns(
    stats: xr.Dataset, geometries: gpd.GeoDataFrame
) -> xr.Dataset:
    """Preserve geometry columns in output statistics.

    Every column of ``geometries`` except the active geometry column is added
    as a coordinate along ``feature`` (rows are looked up through the
    ``index`` coordinate of ``stats``) and appended to the ``feature`` index.
    Returns the updated dataset.
    """
    cols = [
        col for col in geometries.columns if col != geometries._geometry_column_name
    ]
    # One row of ``values`` per preserved column, aligned with ``stats.index``.
    values = geometries.loc[stats.index.values][cols].values.T

    for col, val in zip(cols, values):
        stats = stats.assign_coords({col: ("feature", val)})

    feature_index = list(stats["feature"].to_index().names)
    feature_index.extend(cols)
    stats = stats.set_index(feature=feature_index)

    return stats