Source code for earthdaily.earthdatastore.cube_utils._zonal
"""Zonal statistics computation for geospatial data analysis.This module provides functionality for computing zonal statistics on xarray Datasetsusing various methods including numpy and xvec. It supports parallel processing,memory optimization, and various statistical operations.Example: >>> import xarray as xr >>> import geopandas as gpd >>> dataset = xr.open_dataset("temperature.nc") >>> polygons = gpd.read_file("zones.geojson") >>> stats = zonal_stats(dataset, polygons, reducers=["mean", "max"])"""importloggingimporttimefromtypingimportList,Optional,Tuple,Union,castimportgeopandasasgpdimportnumpyasnptry:importpolarsaspl# type: ignoreexceptImportError:pl=None# type: ignoreimportpsutilimportxarrayasxrfromscipy.sparseimportcsr_matrixfromscipy.statsimportmodefromtqdm.autoimporttrangefrom.preprocessingimportrasterize# Configure logginglogger=logging.getLogger(__name__)classMemoryManager:"""Manages memory allocation for large dataset processing."""@staticmethoddefcalculate_time_chunks(dataset:xr.Dataset,max_memory_mb:Optional[float]=None)->int:"""Calculate optimal time chunks based on available memory. Args: dataset: Input xarray Dataset max_memory_mb: Maximum memory to use in megabytes Returns: int: Optimal number of time chunks """ifmax_memory_mbisNone:max_memory_mb=psutil.virtual_memory().available/1e6logger.info(f"Using maximum available memory: {max_memory_mb:.2f}MB")bytes_per_date=(dataset.nbytes/1e6)/dataset.time.size*3max_chunks=int(np.arange(0,max_memory_mb,bytes_per_date+0.1).size)time_chunks=int(dataset.time.size/np.arange(0,dataset.time.size,max_chunks).size)logger.info(f"Estimated memory per date: {bytes_per_date:.2f}MB. Total: {(bytes_per_date*dataset.time.size):.2f}MB")logger.info(f"Time chunks: {time_chunks} (total time steps: {dataset.time.size})")returntime_chunksclassSpatialIndexer:"""Handles spatial indexing and rasterization operations."""@staticmethoddefcompute_sparse_matrix(data:np.ndarray)->csr_matrix:"""Compute sparse matrix from input data. Args: data: Input numpy array Returns: scipy.sparse.csr_matrix: Computed sparse matrix """cols=np.arange(data.size)returncsr_matrix((cols,(data.ravel(),cols)),shape=(data.max()+1,data.size))@staticmethoddefget_sparse_indices(data:np.ndarray)->List[Tuple[np.ndarray,...]]:"""Get sparse indices from input data. Args: data: Input numpy array Returns: List of index tuples """matrix=SpatialIndexer.compute_sparse_matrix(data)return[np.unravel_index(row.data,data.shape)forrowinmatrix]@staticmethoddefrasterize_geometries(gdf:gpd.GeoDataFrame,dataset:xr.Dataset,all_touched:bool=False,positions:bool=True,)->Union[np.ndarray,Tuple[np.ndarray,List[Tuple[np.ndarray,...]]]]:"""Rasterize geometries to match dataset resolution. Args: gdf: Input GeoDataFrame dataset: Reference dataset for rasterization all_touched: Whether to include all touched pixels positions: Whether to compute position indices Returns: If positions=True, returns Tuple containing features array and positions list Otherwise, returns only the features array """features=rasterize(gdf,dataset,all_touched=all_touched)ifpositions:position_indices=SpatialIndexer.get_sparse_indices(features)returnfeatures,position_indicesreturnfeaturesclassStatisticalOperations:"""Handles statistical computations on spatial data."""@staticmethoddefzonal_stats(dataset:xr.Dataset,positions:np.ndarray,reducers:List[str],method:str="numpy",)->xr.Dataset:"""Compute zonal statistics for given positions using specified reducers. 
class StatisticalOperations:
    """Handles statistical computations on spatial data."""

    @staticmethod
    def zonal_stats(
        dataset: xr.Dataset,
        positions: np.ndarray,
        reducers: List[str],
        method: str = "numpy",
    ) -> xr.Dataset:
        """Compute zonal statistics for given positions using specified reducers.

        Args:
            dataset: Input dataset
            positions: Array of position indices
            reducers: List of statistical operations to perform
            method: Computation method to use

        Returns:
            xarray.Dataset: Computed statistics

        Notes:
            Uses xarray's apply_ufunc for parallel processing and efficient
            computation.
        """

        def _zonal_stats_ufunc(data, positions, reducers):
            """Inner function for parallel computation of zonal statistics."""
            zs = []
            tf = positions != 0
            for idx in np.unique(positions[tf]):
                field_stats = []
                mask = positions == idx
                field_pixels = data[:, mask]
                for reducer in reducers:
                    if reducer == "mode":
                        field_arr = mode(field_pixels, axis=-1, nan_policy="omit").mode
                    else:
                        func = f"nan{reducer}" if hasattr(np, f"nan{reducer}") else reducer
                        field_arr = getattr(np, func)(field_pixels, axis=-1)
                    field_stats.append(field_arr)
                zs.append(np.asarray(field_stats))
            zs = np.asarray(zs)
            # (feature, reducer, time) -> (time, feature, reducer)
            return zs.swapaxes(-1, 0).swapaxes(-1, -2)

        def _zonal_stats_polars_ufunc(data, positions, reducers):
            """Inner function for parallel computation of zonal statistics using polars."""
            zs = []
            tf = positions != 0
            pol_positions = positions[tf]
            n_dims = data.shape[0]
            original_idx = np.arange(np.unique(pol_positions).size) + 1

            # Process each dimension separately
            for dim in range(n_dims):
                # Create DataFrame for this dimension
                df = pl.DataFrame(
                    {"polygon_id": pol_positions, f"dim_{dim}": data[dim, ...][tf]}
                )

                # Compute statistics using Polars group_by
                stats = (
                    df.lazy()
                    .drop_nans()
                    .group_by("polygon_id")
                    .agg(
                        [
                            getattr(pl.col(f"dim_{dim}"), reducer)().alias(reducer)
                            for reducer in reducers
                            if reducer != "mode"
                        ]
                    )
                )

                # Handle mode separately if needed
                if "mode" in reducers:
                    raise NotImplementedError("mode is not yet implemented")
                    # =============================================================
                    # stats_mode = (
                    #     df.lazy()
                    #     .drop_nans()
                    #     .group_by("polygon_id")
                    #     .agg(
                    #         getattr(pl.col(f"dim_{dim}"), "mode")()
                    #         .alias("mode")
                    #         .first()
                    #     )
                    # )
                    # stats = stats.collect().join(stats_mode.collect(), on="polygon_id")
                    # stats_array = stats.sort("polygon_id").select(reducers).to_numpy()
                    # =============================================================
                else:
                    stats_array = stats.collect().sort("polygon_id").to_numpy()
                    idx = stats_array[:, 0].astype(np.int64)
                    stats_array = stats_array[:, 1:]
                    missing_idx = np.setdiff1d(original_idx, idx)
                    if missing_idx.size:
                        # Create a NaN-filled array covering every expected polygon
                        result = np.full((len(original_idx), len(reducers)), np.nan)
                        # Find positions of idx in original_idx for mapping
                        idx_positions = np.searchsorted(original_idx, idx)
                        # Insert the actual data values at the matching rows
                        # (idx_positions is already zero-based, so no extra offset)
                        result[idx_positions, :] = stats_array
                        stats_array = result
                zs.append(stats_array)
            return np.asarray(zs)

        methods = {"numpy": _zonal_stats_ufunc, "polars": _zonal_stats_polars_ufunc}

        # Apply the function using xarray's parallel processing capabilities
        return xr.apply_ufunc(
            methods.get(method, _zonal_stats_ufunc),
            dataset,
            vectorize=False,
            dask="parallelized",
            input_core_dims=[["y", "x"]],
            output_core_dims=[["feature", "zonal_statistics"]],
            exclude_dims={"x", "y"},
            output_dtypes=[float],
            kwargs=dict(reducers=reducers, positions=positions),
            dask_gufunc_kwargs={
                "allow_rechunk": True,
                "output_sizes": dict(
                    feature=np.unique(positions[positions > 0]).size,
                    zonal_statistics=len(reducers),
                ),
            },
        )
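For orientation, a minimal sketch of what the numpy path produces on synthetic data (shapes and values are invented; not part of the module). The (y, x) core dimensions are consumed and replaced by (feature, zonal_statistics):

import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"ndvi": (("time", "y", "x"), np.random.rand(4, 3, 3))},
    coords={"time": np.arange(4)},
)
positions = np.array([[0, 1, 1],
                      [2, 2, 0],
                      [2, 1, 0]])  # rasterized zone ids, 0 = outside

stats = StatisticalOperations.zonal_stats(ds, positions, reducers=["mean", "max"])
print(stats["ndvi"].dims)   # ('time', 'feature', 'zonal_statistics')
print(stats["ndvi"].shape)  # (4, 2, 2): 4 dates x 2 zones x 2 reducers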
def zonal_stats(
    dataset: xr.Dataset,
    geometries: Union[gpd.GeoDataFrame, gpd.GeoSeries],
    method: str = "numpy",
    lazy_load: bool = True,
    max_memory_mb: Optional[float] = None,
    reducers: List[str] = ["mean"],
    all_touched: bool = True,
    preserve_columns: bool = True,
    buffer_meters: Optional[Union[int, float]] = None,
    **kwargs,
) -> xr.Dataset:
    """Calculate zonal statistics for an xarray Dataset based on geometric boundaries.

    This function computes statistical summaries of Dataset values within each
    geometry's zone, supporting parallel processing through xarray's apply_ufunc
    and multiple computation methods.

    Parameters
    ----------
    dataset : xarray.Dataset
        Input dataset containing variables for statistics computation.
    geometries : Union[geopandas.GeoDataFrame, geopandas.GeoSeries]
        Geometries defining the zones for statistics calculation.
    method : str, optional
        Method for computation. Options:
        - 'numpy': Uses numpy functions with parallel processing
        - 'polars': Uses polars for data manipulation
        - 'xvec': Uses xvec library (must be installed)
        Default is 'numpy'.
    lazy_load : bool, optional
        If True, keeps the dataset lazily evaluated. If False, the 'numpy'
        method eagerly loads data in time chunks sized from available memory
        (the former smart_load behavior). Default is True.
    max_memory_mb : float, optional
        Maximum memory to use in megabytes. If None, uses maximum available
        memory. Default is None.
    reducers : list[str], optional
        List of statistical operations to perform. Functions should be numpy
        nan-functions (e.g., 'mean' uses np.nanmean). Default is ['mean'].
    all_touched : bool, optional
        If True, includes all pixels touched by geometries in computation.
        Default is True.
    preserve_columns : bool, optional
        If True, preserves all columns from input geometries in output.
        Default is True.
    buffer_meters : Union[int, float, None], optional
        Buffer distance in meters to apply to geometries before computation.
        Default is None.
    **kwargs : dict
        Additional keyword arguments passed to underlying computation functions.

    Returns
    -------
    xarray.Dataset
        Dataset containing computed statistics with dimensions:
        - time (if present in input)
        - feature (number of geometries)
        - zonal_statistics (number of reducers)
        Additional coordinates include geometry WKT and preserved columns if
        requested.

    See Also
    --------
    xarray.apply_ufunc : Function used for parallel computation
    rasterio.features : Used for geometry rasterization

    Notes
    -----
    Memory usage is optimized for time series data when lazy_load=False by
    processing in chunks determined by available system memory.
    The 'xvec' method requires the xvec package to be installed separately.

    Examples
    --------
    >>> import xarray as xr
    >>> import geopandas as gpd
    >>> dataset = xr.open_dataset("temperature.nc")
    >>> polygons = gpd.read_file("zones.geojson")
    >>> stats = zonal_stats(
    ...     dataset,
    ...     polygons,
    ...     reducers=["mean", "max"],
    ...     lazy_load=True
    ... )

    Raises
    ------
    ImportError
        If 'xvec' method is selected but xvec package is not installed.
    ValueError
        If an invalid method or reducer is specified.
    DeprecationWarning
        If deprecated parameters are used.
    """
    # Input validation and deprecation warnings
    if "label" in kwargs:
        raise DeprecationWarning(
            '"label" parameter is deprecated and removed in earthdaily>=0.5. '
            "All geometry columns are preserved by default (preserve_columns=True)."
        )
    if "smart_load" in kwargs:
        import warnings

        warnings.warn(
            '"smart_load" will be deprecated in earthdaily>=0.6. '
            'Use "lazy_load" instead (lazy_load=True == smart_load=False).',
            DeprecationWarning,
        )
        lazy_load = not kwargs["smart_load"]
''Use "lazy_load" instead (lazy_load=True == smart_load=False).',DeprecationWarning,)lazy_load=notkwargs["smart_load"]# Clip dataset to geometry boundsdataset=dataset.rio.clip_box(*geometries.to_crs(dataset.rio.crs).total_bounds)# Apply buffer if specifiedifbuffer_metersisnotNone:geometries=_apply_buffer(geometries,buffer_meters)ifmethod=="polars":ifplisNone:raiseImportError("The polars method requires the polars package. ""Please install it with: pip install polars")return_compute_polars_stats(dataset,geometries,lazy_load,max_memory_mb,reducers,all_touched,preserve_columns,**kwargs,)elifmethod=="numpy":return_compute_numpy_stats(dataset,geometries,lazy_load,max_memory_mb,reducers,all_touched,preserve_columns,**kwargs,)elifmethod=="xvec":return_compute_xvec_stats(dataset,geometries,reducers,all_touched,preserve_columns,**kwargs)else:raiseValueError(f"Unsupported method: {method}")
def _apply_buffer(geometries: gpd.GeoDataFrame, buffer_meters: Union[int, float]) -> gpd.GeoDataFrame:
    """Apply a buffer, in meters, to geometries."""
    original_crs = geometries.crs
    # Reproject to an equal-area CRS so the buffer distance is expressed in meters
    geometries = geometries.to_crs({"proj": "cea"})
    geometries["geometry_original"] = geometries.geometry
    geometries.geometry = geometries.buffer(buffer_meters)
    return geometries.to_crs(original_crs)
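Why the round-trip through {"proj": "cea"}: buffering directly in a geographic CRS would interpret the distance in degrees, so the helper projects to an equal-area CRS whose units are meters, buffers there, and projects back (shape distortion grows with latitude, but the distance stays metric). A standalone sketch:

import geopandas as gpd
from shapely.geometry import Point

gdf = gpd.GeoDataFrame(geometry=[Point(2.35, 48.85)], crs="EPSG:4326")

# 100 here means 100 meters, not 100 degrees, thanks to the projection hop.
buffered = gdf.to_crs({"proj": "cea"}).buffer(100).to_crs("EPSG:4326")
print(buffered.iloc[0].bounds)  # small envelope in degrees around the point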
def _compute_polars_stats(
    dataset: xr.Dataset,
    geometries: gpd.GeoDataFrame,
    lazy_load: bool,
    max_memory_mb: Optional[float],
    reducers: List[str],
    all_touched: bool,
    preserve_columns: bool,
    **kwargs,
) -> xr.Dataset:
    """Compute zonal statistics using the polars method.

    Args:
        dataset: Input dataset
        geometries: Input geometries
        lazy_load: Whether to optimize memory usage
        max_memory_mb: Maximum memory to use
        reducers: List of statistical operations
        all_touched: Whether to include all touched pixels
        preserve_columns: Whether to preserve geometry columns
        **kwargs: Additional keyword arguments

    Returns:
        xarray.Dataset: Computed statistics
    """
    # Rasterize geometries
    positions = cast(
        np.ndarray,
        SpatialIndexer.rasterize_geometries(geometries.copy(), dataset, all_touched, positions=False),
    )

    stats = StatisticalOperations.zonal_stats(
        dataset=dataset, positions=positions, reducers=reducers, method="polars"
    )

    # Format output
    return _format_numpy_output(stats, positions, geometries, reducers, preserve_columns)


def _compute_numpy_stats(
    dataset: xr.Dataset,
    geometries: gpd.GeoDataFrame,
    lazy_load: bool,
    max_memory_mb: Optional[float],
    reducers: List[str],
    all_touched: bool,
    preserve_columns: bool,
    **kwargs,
) -> xr.Dataset:
    """Compute zonal statistics using the numpy method.

    Args:
        dataset: Input dataset
        geometries: Input geometries
        lazy_load: Whether to optimize memory usage
        max_memory_mb: Maximum memory to use
        reducers: List of statistical operations
        all_touched: Whether to include all touched pixels
        preserve_columns: Whether to preserve geometry columns
        **kwargs: Additional keyword arguments

    Returns:
        xarray.Dataset: Computed statistics
    """
    # Rasterize geometries
    positions = cast(
        np.ndarray,
        SpatialIndexer.rasterize_geometries(geometries.copy(), dataset, all_touched, positions=False),
    )

    # Process the time series in memory-sized chunks when eager loading is requested
    if "time" in dataset.dims and not lazy_load:
        time_chunks = MemoryManager.calculate_time_chunks(dataset, max_memory_mb)
        stats = _process_time_chunks(dataset, positions, reducers, lazy_load, time_chunks)
    else:
        stats = StatisticalOperations.zonal_stats(
            dataset=dataset, positions=positions, reducers=reducers, method="numpy"
        )

    # Format output
    return _format_numpy_output(stats, positions, geometries, reducers, preserve_columns)


def _process_time_chunks(
    dataset: xr.Dataset,
    positions: np.ndarray,
    reducers: List[str],
    lazy_load: bool,
    time_chunks: int,
) -> xr.Dataset:
    """Process the dataset in time chunks to optimize memory usage.

    Args:
        dataset: Input dataset
        positions: Position indices
        reducers: List of statistical operations
        lazy_load: Whether to optimize memory usage
        time_chunks: Number of time steps per chunk

    Returns:
        xarray.Dataset: Computed statistics
    """
    chunks = []
    for time_idx in trange(0, dataset.time.size, time_chunks):
        end_idx = min(time_idx + time_chunks, dataset.time.size)
        ds_chunk = dataset.isel(time=slice(time_idx, end_idx))
        if not lazy_load:
            load_start = time.time()
            ds_chunk = ds_chunk.load()
            logger.debug(
                f"Loaded {ds_chunk.time.size} dates in {(time.time() - load_start):.2f}s"
            )
        compute_start = time.time()
        chunk_stats = StatisticalOperations.zonal_stats(ds_chunk, positions, reducers)
        logger.debug(f"Computed chunk statistics in {(time.time() - compute_start):.2f}s")
        chunks.append(chunk_stats)
    return xr.concat(chunks, dim="time")


def _compute_xvec_stats(
    dataset: xr.Dataset,
    geometries: gpd.GeoDataFrame,
    reducers: List[str],
    all_touched: bool,
    preserve_columns: bool,
    **kwargs,
) -> xr.Dataset:
    """Compute zonal statistics using the xvec method.

    Args:
        dataset: Input dataset
        geometries: Input geometries
        reducers: List of statistical operations to perform
        all_touched: Whether to include all touched pixels
        preserve_columns: Whether to preserve geometry columns
        **kwargs: Additional keyword arguments

    Returns:
        xarray.Dataset: Computed statistics

    Raises:
        ImportError: If the xvec package is not installed
    """
    try:
        import xvec  # noqa
    except ImportError:
        raise ImportError(
            "The xvec method requires the xvec package. "
            "Please install it with: pip install xvec"
        )

    # Compute statistics using xvec
    stats = dataset.xvec.zonal_stats(
        geometries.to_crs(dataset.rio.crs).geometry,
        y_coords="y",
        x_coords="x",
        stats=reducers,
        method="rasterize",
        all_touched=all_touched,
        **kwargs,
    )

    # Drop geometry and re-add it as a coordinate
    stats = stats.drop("geometry")
    stats = stats.assign_coords(
        geometry=("feature", geometries.geometry.to_wkt(rounding_precision=-1).values)
    )

    # Add index coordinate
    stats = stats.assign_coords(index=("feature", geometries.index))
    stats = stats.set_index(feature=["geometry", "index"])

    # Transpose dimensions to match the numpy method output
    stats = stats.transpose("time", "feature", "zonal_statistics")

    # Preserve additional columns if requested
    if preserve_columns:
        stats = _preserve_geometry_columns(stats, geometries)

    return stats


def _format_numpy_output(
    stats: xr.Dataset,
    features: np.ndarray,
    geometries: gpd.GeoDataFrame,
    reducers: List[str],
    preserve_columns: bool,
) -> xr.Dataset:
    """Format numpy statistics output.

    Args:
        stats: Computed statistics
        features: Features array
        geometries: Input geometries
        reducers: List of statistical operations
        preserve_columns: Whether to preserve geometry columns

    Returns:
        xarray.Dataset: Formatted statistics
    """
    # Set coordinates and metadata
    stats = stats.assign_coords(zonal_statistics=reducers)
    stats = stats.rio.write_crs("EPSG:4326")

    # Process features and create index
    feature_indices = np.unique(features)
    feature_indices = feature_indices[feature_indices > 0]
    index = geometries.index[feature_indices - 1]

    # Convert geometries to WKT
    if geometries.crs.to_epsg() != 4326:
        geometries = geometries.to_crs("EPSG:4326")
    geometry_wkt = (
        geometries.geometry.iloc[feature_indices - 1].to_wkt(rounding_precision=-1).values
    )

    # Assign coordinates
    coords = {"index": (["feature"], index), "geometry": (["feature"], geometry_wkt)}
    stats = stats.assign_coords(coords)
    stats = stats.set_index(feature=list(coords.keys()))

    # Preserve additional columns if requested
    if preserve_columns:
        stats = _preserve_geometry_columns(stats, geometries)

    return stats
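The formatted result indexes the feature dimension with a pandas MultiIndex built from the original GeoDataFrame index and the geometry WKT. A short sketch of reading it back, continuing the hypothetical stats and polygons from the earlier example:

from shapely import wkt

feature_index = stats["feature"].to_index()  # pandas MultiIndex
print(feature_index.names)                   # ['index', 'geometry', ...]

# Select one polygon via its original GeoDataFrame index label.
one_zone = stats.sel(index=polygons.index[0])

# Geometries round-trip from the stored WKT when needed.
geom = wkt.loads(str(stats["geometry"].values[0]))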
def _preserve_geometry_columns(stats: xr.Dataset, geometries: gpd.GeoDataFrame) -> xr.Dataset:
    """Preserve geometry columns in output statistics.

    Args:
        stats: Computed statistics
        geometries: Input geometries

    Returns:
        xarray.Dataset: Statistics with preserved columns
    """
    cols = [col for col in geometries.columns if col != geometries._geometry_column_name]
    values = geometries.loc[stats.index.values][cols].values.T
    for col, val in zip(cols, values):
        stats = stats.assign_coords({col: ("feature", val)})
    feature_index = list(stats["feature"].to_index().names)
    feature_index.extend(cols)
    stats = stats.set_index(feature=feature_index)
    return stats
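Finally, an illustration of preserve_columns: attribute columns of the input GeoDataFrame become extra levels of the feature index. The "crop" column and its values below are hypothetical, and dataset is the one from the earlier sketch:

import geopandas as gpd
from shapely.geometry import box

polygons = gpd.GeoDataFrame(
    {"crop": ["wheat", "corn"]},  # hypothetical attribute column
    geometry=[box(0, 0, 1, 1), box(1, 0, 2, 1)],
    crs="EPSG:4326",
)

stats = zonal_stats(dataset, polygons, reducers=["mean"])
print(stats["crop"].values)      # ['wheat', 'corn']
wheat = stats.sel(crop="wheat")  # select by the preserved attribute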