"""Zonal statistics computation for geospatial data analysis.
This module provides functionality for computing zonal statistics on xarray Datasets
using various methods including numpy, polars, and xvec. It supports parallel processing,
memory optimization, and various statistical operations.
Example:
>>> import xarray as xr
>>> import geopandas as gpd
>>> dataset = xr.open_dataset("temperature.nc")
>>> polygons = gpd.read_file("zones.geojson")
>>> stats = zonal_stats(dataset, polygons, reducers=["mean", "max"])
"""
import logging
import time
from typing import List, Optional, Tuple, Union, cast
import geopandas as gpd
import numpy as np
try:
import polars as pl # type: ignore
except ImportError:
pl = None # type: ignore
import psutil
import xarray as xr
from scipy.sparse import csr_matrix
from scipy.stats import mode
from tqdm.auto import trange
from .preprocessing import rasterize
# Configure logging
logger = logging.getLogger(__name__)
class MemoryManager:
"""Manages memory allocation for large dataset processing."""
@staticmethod
def calculate_time_chunks(
dataset: xr.Dataset, max_memory_mb: Optional[float] = None
) -> int:
"""Calculate optimal time chunks based on available memory.
Args:
dataset: Input xarray Dataset
max_memory_mb: Maximum memory to use in megabytes
Returns:
int: Optimal number of time chunks
"""
if max_memory_mb is None:
max_memory_mb = psutil.virtual_memory().available / 1e6
logger.info(f"Using maximum available memory: {max_memory_mb:.2f}MB")
bytes_per_date = (dataset.nbytes / 1e6) / dataset.time.size * 3
max_chunks = int(np.arange(0, max_memory_mb, bytes_per_date + 0.1).size)
time_chunks = int(
dataset.time.size / np.arange(0, dataset.time.size, max_chunks).size
)
logger.info(
f"Estimated memory per date: {bytes_per_date:.2f}MB. Total: {(bytes_per_date * dataset.time.size):.2f}MB"
)
logger.info(
f"Time chunks: {time_chunks} (total time steps: {dataset.time.size})"
)
return time_chunks
class SpatialIndexer:
"""Handles spatial indexing and rasterization operations."""
@staticmethod
def compute_sparse_matrix(data: np.ndarray) -> csr_matrix:
"""Compute sparse matrix from input data.
Args:
data: Input numpy array
Returns:
scipy.sparse.csr_matrix: Computed sparse matrix
"""
cols = np.arange(data.size)
return csr_matrix(
(cols, (data.ravel(), cols)), shape=(data.max() + 1, data.size)
)
@staticmethod
def get_sparse_indices(data: np.ndarray) -> List[Tuple[np.ndarray, ...]]:
"""Get sparse indices from input data.
Args:
data: Input numpy array
Returns:
List of index tuples
"""
matrix = SpatialIndexer.compute_sparse_matrix(data)
return [np.unravel_index(row.data, data.shape) for row in matrix]
@staticmethod
def rasterize_geometries(
gdf: gpd.GeoDataFrame,
dataset: xr.Dataset,
all_touched: bool = False,
positions: bool = True,
) -> Union[np.ndarray, Tuple[np.ndarray, List[Tuple[np.ndarray, ...]]]]:
"""Rasterize geometries to match dataset resolution.
Args:
gdf: Input GeoDataFrame
dataset: Reference dataset for rasterization
all_touched: Whether to include all touched pixels
positions: Whether to compute position indices
Returns:
If positions=True, returns Tuple containing features array and positions list
Otherwise, returns only the features array
"""
features = rasterize(gdf, dataset, all_touched=all_touched)
if positions:
position_indices = SpatialIndexer.get_sparse_indices(features)
return features, position_indices
return features
class StatisticalOperations:
"""Handles statistical computations on spatial data."""
@staticmethod
def zonal_stats(
dataset: xr.Dataset,
positions: np.ndarray,
reducers: List[str],
method: str = "numpy",
) -> xr.Dataset:
"""Compute zonal statistics for given positions using specified reducers.
Args:
dataset: Input dataset
            positions: Rasterized zone-label array aligned with the dataset grid (0 = background)
reducers: List of statistical operations to perform
method: Computation method to use
Returns:
xarray.Dataset: Computed statistics
Notes:
Uses xarray's apply_ufunc for parallel processing and efficient computation
"""
def _zonal_stats_ufunc(data, positions, reducers):
"""Inner function for parallel computation of zonal statistics."""
zs = []
tf = positions != 0
            for idx in np.unique(positions[tf]):
                mask = positions == idx
                field_stats = []
                for reducer in reducers:
                    field_arr = data[:, mask]
if reducer == "mode":
field_arr = mode(field_arr, axis=-1, nan_policy="omit").mode
else:
func = (
f"nan{reducer}" if hasattr(np, f"nan{reducer}") else reducer
)
field_arr = getattr(np, func)(field_arr, axis=-1)
field_stats.append(field_arr)
field_stats = np.asarray(field_stats)
zs.append(field_stats)
zs = np.asarray(zs)
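            # zs has shape (feature, reducer, time); reorder to (time, feature, reducer)
            # to match the declared output core dims (feature, zonal_statistics).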
return zs.swapaxes(-1, 0).swapaxes(-1, -2)
def _zonal_stats_polars_ufunc(data, positions, reducers):
"""Inner function for parallel computation of zonal statistics using polars."""
zs = []
tf = positions != 0
pol_positions = positions[tf]
n_dims = data.shape[0]
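            # data arrives with the core dims (y, x) last (e.g. (time, y, x)), so the
            # loop below runs one group-by per leading slice (typically one time step).
            # Zone labels are assumed to be contiguous 1..n; original_idx is used to
            # re-insert NaN rows for zones whose pixels are all dropped by drop_nans().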
original_idx = np.arange(np.unique(pol_positions).size) + 1
# Process each dimension separately
for dim in range(n_dims):
# Create DataFrame for dimension
df = pl.DataFrame(
{"polygon_id": pol_positions, f"dim_{dim}": data[dim, ...][tf]}
)
# Compute statistics using Polars groupby
stats = (
df.lazy()
.drop_nans()
.group_by("polygon_id")
.agg(
[
getattr(pl.col(f"dim_{dim}"), reducer)().alias(reducer)
for reducer in reducers
if reducer != "mode"
]
)
)
# Handle mode separately if needed
if "mode" in reducers:
raise NotImplementedError("mode is not yet implemented")
# =============================================================================
# stats_mode = (
# df.lazy()
# .drop_nans()
# .group_by("polygon_id")
# .agg(
# getattr(pl.col(f"dim_{dim}"), "mode")()
# .alias("mode")
# .first()
# )
# )
# stats = stats.collect().join(stats_mode.collect(), on="polygon_id")
# stats_array = stats.sort("polygon_id").select(reducers).to_numpy()
# =============================================================================
else:
stats_array = stats.collect().sort("polygon_id").to_numpy()
idx = stats_array[:, 0].astype(np.int64)
stats_array = stats_array[:, 1:]
missing_idx = np.setdiff1d(original_idx, idx)
if missing_idx.size:
# Create a mask for the final array
result = np.full((len(original_idx), len(reducers)), np.nan)
# Find positions of idx in original_idx for mapping
idx_positions = np.searchsorted(original_idx, idx)
# Insert the actual data values at the correct positions
                    result[idx_positions, :] = stats_array
stats_array = result
zs.append(stats_array)
return np.asarray(zs)
methods = {"numpy": _zonal_stats_ufunc, "polars": _zonal_stats_polars_ufunc}
# Apply the function using xarray's parallel processing capabilities
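        # The ufunc collapses the (y, x) core dims and introduces two new output dims,
        # so their sizes must be declared in output_sizes; allow_rechunk permits inputs
        # whose (y, x) core dims are split across dask chunks.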
return xr.apply_ufunc(
methods.get(method, _zonal_stats_ufunc),
dataset,
vectorize=False,
dask="parallelized",
input_core_dims=[["y", "x"]],
output_core_dims=[["feature", "zonal_statistics"]],
exclude_dims=set(["x", "y"]),
output_dtypes=[float],
kwargs=dict(reducers=reducers, positions=positions),
dask_gufunc_kwargs={
"allow_rechunk": True,
"output_sizes": dict(
feature=np.unique(positions[positions > 0]).size,
zonal_statistics=len(reducers),
),
},
)
def zonal_stats(
dataset: xr.Dataset,
geometries: Union[gpd.GeoDataFrame, gpd.GeoSeries],
method: str = "numpy",
lazy_load: bool = True,
max_memory_mb: Optional[float] = None,
reducers: List[str] = ["mean"],
all_touched: bool = True,
preserve_columns: bool = True,
buffer_meters: Optional[Union[int, float]] = None,
**kwargs,
) -> xr.Dataset:
"""Calculate zonal statistics for xarray Dataset based on geometric boundaries.
This function computes statistical summaries of Dataset values within each geometry's zone,
supporting parallel processing through xarray's apply_ufunc and multiple computation methods.
Parameters
----------
dataset : xarray.Dataset
Input dataset containing variables for statistics computation.
geometries : Union[geopandas.GeoDataFrame, geopandas.GeoSeries]
Geometries defining the zones for statistics calculation.
method : str, optional
Method for computation. Options:
- 'numpy': Uses numpy functions with parallel processing
- 'polars': Uses polars for data manipulation
- 'xvec': Uses xvec library (must be installed)
Default is 'numpy'.
    lazy_load : bool, optional
        If True, keeps the dataset lazy (dask-backed) and lets xarray's apply_ufunc
        parallelize the computation. If False, the 'numpy' method eagerly loads the
        data in time chunks sized to the available memory.
        Default is True.
    max_memory_mb : float, optional
        Maximum memory to use in megabytes when loading time chunks (lazy_load=False).
        If None, uses the maximum available memory.
        Default is None.
reducers : list[str], optional
List of statistical operations to perform. Functions should be numpy nan-functions
(e.g., 'mean' uses np.nanmean). Default is ['mean'].
all_touched : bool, optional
If True, includes all pixels touched by geometries in computation.
Default is True.
preserve_columns : bool, optional
If True, preserves all columns from input geometries in output.
Default is True.
buffer_meters : Union[int, float, None], optional
Buffer distance in meters to apply to geometries before computation.
Default is None.
**kwargs : dict
Additional keyword arguments passed to underlying computation functions.
Returns
-------
xarray.Dataset
Dataset containing computed statistics with dimensions:
- time (if present in input)
- feature (number of geometries)
- zonal_statistics (number of reducers)
Additional coordinates include geometry WKT and preserved columns if requested.
See Also
--------
xarray.apply_ufunc : Function used for parallel computation
rasterio.features : Used for geometry rasterization
Notes
-----
    Memory usage for time series data is optimized when lazy_load=False: the 'numpy'
    method then loads and processes the data in time chunks sized to the available
    system memory.
The 'xvec' method requires the xvec package to be installed separately.
Examples
--------
>>> import xarray as xr
>>> import geopandas as gpd
>>> dataset = xr.open_dataset("temperature.nc")
>>> polygons = gpd.read_file("zones.geojson")
    >>> stats = zonal_stats(
... dataset,
... polygons,
... reducers=["mean", "max"],
... lazy_load=True
... )
Raises
------
ImportError
If 'xvec' method is selected but xvec package is not installed.
ValueError
If invalid method or reducer is specified.
DeprecationWarning
If deprecated parameters are used.
"""
# Input validation and deprecation warnings
if "label" in kwargs:
raise DeprecationWarning(
'"label" parameter is deprecated and removed in earthdaily>=0.5. '
"All geometry columns are preserved by default (preserve_columns=True)."
)
if "smart_load" in kwargs:
import warnings
warnings.warn(
'"smart_load" will be deprecated in earthdaily>=0.6. '
'Use "lazy_load" instead (lazy_load=True == smart_load=False).',
DeprecationWarning,
)
lazy_load = not kwargs["smart_load"]
# Clip dataset to geometry bounds
dataset = dataset.rio.clip_box(*geometries.to_crs(dataset.rio.crs).total_bounds)
# Apply buffer if specified
if buffer_meters is not None:
geometries = _apply_buffer(geometries, buffer_meters)
if method == "polars":
if pl is None:
raise ImportError(
"The polars method requires the polars package. "
"Please install it with: pip install polars"
)
return _compute_polars_stats(
dataset,
geometries,
lazy_load,
max_memory_mb,
reducers,
all_touched,
preserve_columns,
**kwargs,
)
elif method == "numpy":
return _compute_numpy_stats(
dataset,
geometries,
lazy_load,
max_memory_mb,
reducers,
all_touched,
preserve_columns,
**kwargs,
)
elif method == "xvec":
return _compute_xvec_stats(
dataset, geometries, reducers, all_touched, preserve_columns, **kwargs
)
else:
raise ValueError(f"Unsupported method: {method}")
def _apply_buffer(
geometries: gpd.GeoDataFrame, buffer_meters: Union[int, float]
) -> gpd.GeoDataFrame:
"""Apply buffer to geometries in meters."""
original_crs = geometries.crs
geometries = geometries.to_crs({"proj": "cea"})
geometries["geometry_original"] = geometries.geometry
geometries.geometry = geometries.buffer(buffer_meters)
return geometries.to_crs(original_crs)
def _compute_polars_stats(
dataset: xr.Dataset,
geometries: gpd.GeoDataFrame,
lazy_load: bool,
max_memory_mb: Optional[float],
reducers: List[str],
all_touched: bool,
preserve_columns: bool,
**kwargs,
) -> xr.Dataset:
"""Compute zonal statistics using polars method.
Args:
dataset: Input dataset
geometries: Input geometries
        lazy_load: Currently unused by the polars method
        max_memory_mb: Currently unused by the polars method
reducers: List of statistical operations
all_touched: Whether to include all touched pixels
preserve_columns: Whether to preserve geometry columns
**kwargs: Additional keyword arguments
Returns:
xarray.Dataset: Computed statistics
"""
# Rasterize geometries
positions = cast(
np.ndarray,
SpatialIndexer.rasterize_geometries(
geometries.copy(), dataset, all_touched, positions=False
),
)
stats = StatisticalOperations.zonal_stats(
dataset=dataset, positions=positions, reducers=reducers, method="polars"
)
# Format output
return _format_numpy_output(
stats, positions, geometries, reducers, preserve_columns
)
def _compute_numpy_stats(
dataset: xr.Dataset,
geometries: gpd.GeoDataFrame,
lazy_load: bool,
max_memory_mb: Optional[float],
reducers: List[str],
all_touched: bool,
preserve_columns: bool,
**kwargs,
) -> xr.Dataset:
"""Compute zonal statistics using numpy method.
Args:
dataset: Input dataset
geometries: Input geometries
        lazy_load: If False, load and process the data eagerly in time chunks
max_memory_mb: Maximum memory to use
reducers: List of statistical operations
all_touched: Whether to include all touched pixels
preserve_columns: Whether to preserve geometry columns
**kwargs: Additional keyword arguments
Returns:
xarray.Dataset: Computed statistics
"""
# Rasterize geometries
positions = cast(
np.ndarray,
SpatialIndexer.rasterize_geometries(
geometries.copy(), dataset, all_touched, positions=False
),
)
# Process time series if present
if "time" in dataset.dims and not lazy_load:
time_chunks = MemoryManager.calculate_time_chunks(dataset, max_memory_mb)
stats = _process_time_chunks(
dataset, positions, reducers, lazy_load, time_chunks
)
else:
stats = StatisticalOperations.zonal_stats(
dataset=dataset, positions=positions, reducers=reducers, method="numpy"
)
# Format output
return _format_numpy_output(
stats, positions, geometries, reducers, preserve_columns
)
def _process_time_chunks(
dataset: xr.Dataset,
positions: np.ndarray,
reducers: List[str],
lazy_load: bool,
time_chunks: int,
) -> xr.Dataset:
"""Process dataset in time chunks to optimize memory usage.
Args:
dataset: Input dataset
positions: Position indices
reducers: List of statistical operations
lazy_load: Whether to optimize memory usage
time_chunks: Number of time chunks
Returns:
xarray.Dataset: Computed statistics
"""
chunks = []
for time_idx in trange(0, dataset.time.size, time_chunks):
end_idx = min(time_idx + time_chunks, dataset.time.size)
ds_chunk = dataset.isel(time=slice(time_idx, end_idx))
if not lazy_load:
load_start = time.time()
ds_chunk = ds_chunk.load()
logger.debug(
f"Loaded {ds_chunk.time.size} dates in "
f"{(time.time() - load_start):.2f}s"
)
compute_start = time.time()
chunk_stats = StatisticalOperations.zonal_stats(ds_chunk, positions, reducers)
logger.debug(
f"Computed chunk statistics in {(time.time() - compute_start):.2f}s"
)
chunks.append(chunk_stats)
return xr.concat(chunks, dim="time")
def _compute_xvec_stats(
dataset: xr.Dataset,
geometries: gpd.GeoDataFrame,
reducers: List[str],
all_touched: bool,
preserve_columns: bool,
**kwargs,
) -> xr.Dataset:
"""Compute zonal statistics using xvec method.
Args:
dataset: Input dataset
geometries: Input geometries
reducers: List of statistical operations to perform
all_touched: Whether to include all touched pixels
preserve_columns: Whether to preserve geometry columns
**kwargs: Additional keyword arguments
Returns:
xarray.Dataset: Computed statistics
Raises:
ImportError: If xvec package is not installed
"""
try:
import xvec # noqa
except ImportError:
raise ImportError(
"The xvec method requires the xvec package. "
"Please install it with: pip install xvec"
)
# Compute statistics using xvec
stats = dataset.xvec.zonal_stats(
geometries.to_crs(dataset.rio.crs).geometry,
y_coords="y",
x_coords="x",
stats=reducers,
method="rasterize",
all_touched=all_touched,
**kwargs,
)
# Drop geometry and add as coordinate
    stats = stats.drop_vars("geometry")
stats = stats.assign_coords(
geometry=("feature", geometries.geometry.to_wkt(rounding_precision=-1).values)
)
# Add index coordinate
stats = stats.assign_coords(index=("feature", geometries.index))
stats = stats.set_index(feature=["geometry", "index"])
# Transpose dimensions to match numpy method output
stats = stats.transpose("time", "feature", "zonal_statistics")
# Preserve additional columns if requested
if preserve_columns:
stats = _preserve_geometry_columns(stats, geometries)
return stats
def _format_numpy_output(
stats: xr.Dataset,
features: np.ndarray,
geometries: gpd.GeoDataFrame,
reducers: List[str],
preserve_columns: bool,
) -> xr.Dataset:
"""Format numpy statistics output.
Args:
stats: Computed statistics
        features: Rasterized zone-label array (0 = background)
geometries: Input geometries
reducers: List of statistical operations
preserve_columns: Whether to preserve geometry columns
Returns:
xarray.Dataset: Formatted statistics
"""
# Set coordinates and metadata
stats = stats.assign_coords(zonal_statistics=reducers)
stats = stats.rio.write_crs("EPSG:4326")
# Process features and create index
feature_indices = np.unique(features)
feature_indices = feature_indices[feature_indices > 0]
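    # Rasterized labels are 1-based and follow the row order of `geometries`, so label
    # k maps to the k-th input geometry (label 0 is the background).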
index = geometries.index[feature_indices - 1]
# Convert geometries to WKT
if geometries.crs.to_epsg() != 4326:
geometries = geometries.to_crs("EPSG:4326")
geometry_wkt = (
geometries.geometry.iloc[feature_indices - 1]
.to_wkt(rounding_precision=-1)
.values
)
# Assign coordinates
coords = {"index": (["feature"], index), "geometry": (["feature"], geometry_wkt)}
stats = stats.assign_coords(coords)
stats = stats.set_index(feature=list(coords.keys()))
# Preserve additional columns if requested
if preserve_columns:
stats = _preserve_geometry_columns(stats, geometries)
return stats
def _preserve_geometry_columns(
stats: xr.Dataset, geometries: gpd.GeoDataFrame
) -> xr.Dataset:
"""Preserve geometry columns in output statistics.
Args:
stats: Computed statistics
geometries: Input geometries
Returns:
xarray.Dataset: Statistics with preserved columns
"""
cols = [
col for col in geometries.columns if col != geometries._geometry_column_name
]
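    # Align the attribute columns with the output features via the "index" coordinate,
    # then attach each column as an extra coordinate on the feature dimension.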
values = geometries.loc[stats.index.values][cols].values.T
for col, val in zip(cols, values):
stats = stats.assign_coords({col: ("feature", val)})
feature_index = list(stats["feature"].to_index().names)
feature_index.extend(cols)
stats = stats.set_index(feature=feature_index)
return stats