Source code for earthdaily.earthdatastore.cube_utils._zonal

# -*- coding: utf-8 -*-
"""
Created on Fri Oct 13 09:32:31 2023

@author: nkk
"""

from rasterio import features
import numpy as np
import xarray as xr
import tqdm
import logging
import time
from .preprocessing import rasterize
from scipy.sparse import csr_matrix
from scipy.stats import mode


def _compute_M(data):
    cols = np.arange(data.size)
    return csr_matrix((cols, (data.ravel(), cols)), shape=(data.max() + 1, data.size))


def _indices_sparse(data):
    """
    List, for every integer value of ``data``, the positions holding it.

    Parameters
    ----------
    data : np.ndarray
        Array of non-negative integers.

    Returns
    -------
    list of tuple of np.ndarray
        Element ``v`` is the ``np.unravel_index`` result (one index array per
        dimension of ``data``) for the cells whose value is ``v``. The list
        has ``data.max() + 1`` elements.
    """
    grouped = _compute_M(data)
    per_value_indices = []
    # Each sparse row stores the flat positions of one value; convert them
    # back to N-d indices in the shape of ``data``.
    for value_row in grouped:
        per_value_indices.append(np.unravel_index(value_row.data, data.shape))
    return per_value_indices


def datacube_time_stats(datacube, operations):
    """
    Apply several reductions per ``time`` group and stack the results.

    Parameters
    ----------
    datacube : xr.Dataset or xr.DataArray
        Object exposing ``groupby("time")``.
    operations : iterable of str
        Names of the group-by reduction methods to call (each is looked up
        with ``getattr`` on the grouped object and called with ``...``).

    Returns
    -------
    xarray object
        Concatenation of every reduction along a new ``stats`` dimension
        labelled with the operation names.
    """
    grouped_by_time = datacube.groupby("time")
    per_operation = [
        getattr(grouped_by_time, name)(...).expand_dims(dim={"stats": [name]})
        for name in operations
    ]
    return xr.concat(per_operation, dim="stats")


def _rasterize(gdf, dataset, all_touched=False):
    """
    Rasterize ``gdf`` on ``dataset`` and index the cells of each value.

    Parameters
    ----------
    gdf : geometries forwarded to ``rasterize``
        Presumably a GeoDataFrame -- confirm against ``preprocessing.rasterize``.
    dataset : object forwarded to ``rasterize``
        Presumably the datacube providing the target grid -- TODO confirm.
    all_touched : bool, optional
        Forwarded unchanged to ``rasterize``. The default is False.

    Returns
    -------
    tuple
        ``(feats, yx_pos)`` where ``feats`` is the output of ``rasterize`` and
        ``yx_pos`` is ``_indices_sparse(feats)``.
    """
    rasterized = rasterize(gdf, dataset, all_touched=all_touched)
    return rasterized, _indices_sparse(rasterized)


def _memory_time_chunks(dataset, memory=None):
    import psutil

    if memory is None:
        memory = psutil.virtual_memory().available / 1e6
        logging.debug(f"Hoping to use a maximum memory {memory}Mo.")
    nbytes_per_date = int(dataset.nbytes / 1e6) / dataset.time.size * 3
    max_time_chunks = int(np.arange(0, memory, nbytes_per_date + 0.1).size)
    time_chunks = int(
        dataset.time.size / np.arange(0, dataset.time.size, max_time_chunks).size
    )
    logging.debug(
        f"Mo per date : {nbytes_per_date:0.2f}, total : {(nbytes_per_date*dataset.time.size):0.2f}."
    )
    logging.debug(f"Time chunks : {time_chunks} (on {dataset.time.size} time).")
    return time_chunks


def _zonal_stats_numpy(
    dataset: xr.Dataset, positions, reducers: list = None, all_touched=False
):
    """
    Reduce the pixels of each feature with numpy through ``xr.apply_ufunc``.

    Parameters
    ----------
    dataset : xr.Dataset
        Datacube with ``y`` and ``x`` dimensions.
    positions : sequence
        One entry per feature; each entry holds the ``(y, x)`` index arrays
        of the pixels belonging to that feature.
    reducers : list of str, optional
        Reducer names. ``"mode"`` uses ``scipy.stats.mode``; any other name
        uses ``np.nan<name>`` when it exists, else ``np.<name>``.
        The default is ``["mean"]``.
    all_touched : bool, optional
        Not used by this function; kept so existing callers keep working.

    Returns
    -------
    xarray object
        Same leading dimensions as the input, with ``y``/``x`` replaced by
        ``feature`` and ``zonal_statistics``.
    """
    if reducers is None:
        reducers = ["mean"]

    def _zonal_stats_ufunc(dataset, positions, reducers):
        # ``dataset`` is a plain array here, with y and x as its last two axes.
        zs = []
        for position in positions:
            # The pixel extraction does not depend on the reducer: do it once
            # per feature instead of once per (feature, reducer) pair.
            field_arr = dataset[(...,) + tuple(position)]
            field_stats = []
            for reducer in reducers:
                if reducer == "mode":
                    stat = mode(field_arr, axis=-1, nan_policy="omit").mode
                else:
                    func = f"nan{reducer}" if hasattr(np, f"nan{reducer}") else reducer
                    stat = getattr(np, func)(field_arr, axis=-1)
                field_stats.append(stat)
            zs.append(np.asarray(field_stats))
        # zs is (feature, statistic, *leading dims); apply_ufunc expects the
        # output core dims last. moveaxis keeps the leading dims in order
        # whatever their number.
        return np.moveaxis(np.asarray(zs), (0, 1), (-2, -1))

    dask_ufunc = "parallelized"

    zs = xr.apply_ufunc(
        _zonal_stats_ufunc,
        dataset,
        vectorize=False,
        dask=dask_ufunc,
        input_core_dims=[["y", "x"]],
        output_core_dims=[["feature", "zonal_statistics"]],
        exclude_dims={"x", "y"},
        output_dtypes=[float],
        kwargs=dict(reducers=reducers, positions=positions),
        dask_gufunc_kwargs={
            "allow_rechunk": True,
            "output_sizes": dict(
                feature=len(positions), zonal_statistics=len(reducers)
            ),
        },
    )

    return zs


def zonal_stats(
    dataset: xr.Dataset,
    geoms,
    method: str = "numpy",
    smart_load: bool = False,
    memory: int = None,
    reducers: list = None,
    all_touched=True,
    label=None,
    buffer_meters: int | float | None = None,
    **kwargs,
):
    """
    Xr Zonal stats using np.nan functions.

    Parameters
    ----------
    dataset : xr.Dataset
        Datacube with ``y``/``x`` dimensions and a CRS readable through the
        ``rio`` accessor.
    geoms : GeoDataFrame
        Geometries to summarise (``to_crs``, ``buffer``, ``index`` and
        ``geometry`` are used). It is not modified.
    method : str
        "xvec" or "numpy". The default is "numpy".
    smart_load : bool
        Will load in memory the maximum of time and loop on it for "numpy"
        method. The default is False.
    memory : int, optional
        Only for the "numpy" method, by default it will take the maximum memory
        available. But in some cases it can be too much or too little.
        The default is None.
    reducers : list, optional
        Any np.nan function ("mean" is "np.nanmean"). The default is ['mean'].
    all_touched : bool, optional
        Forwarded to the rasterisation step. The default is True.
    label : str, optional
        Column of ``geoms`` added as a ``label`` coordinate ("numpy" method
        only). The default is None.
    buffer_meters : int or float, optional
        When given, geometries are buffered by this distance (computed in a
        cylindrical equal-area projection) before computing statistics. The
        geometries reported in the output stay the un-buffered ones.
        The default is None.
    **kwargs
        Forwarded to the underlying zonal statistics implementation
        (``xvec`` method, or ``numpy`` method without ``smart_load``).

    Returns
    -------
    zs : xarray object
        Statistics with ``feature`` and ``zonal_statistics`` dimensions. With
        the "numpy" method, ``feature`` is indexed by ``index``, ``geometry``
        (WKT in EPSG:4326) and, when requested, ``label``.

    Raises
    ------
    ValueError
        If ``method`` is unknown, or if no geometry overlaps the dataset with
        the "numpy" method.
    """
    if method not in ("numpy", "xvec"):
        raise ValueError(f"Unknown method {method!r}, expected 'numpy' or 'xvec'.")
    if reducers is None:
        reducers = ["mean"]

    def _loop_time_chunks(dataset, positions, time_chunks):
        # Yield the zonal statistics of successive time batches, each batch
        # being loaded in memory before it is reduced.
        logging.debug(
            f"Batching every {time_chunks} dates ({np.ceil(dataset.time.size/time_chunks).astype(int)} loops)."
        )
        for time_idx in tqdm.trange(0, dataset.time.size, time_chunks):
            ds = dataset.isel(time=slice(time_idx, time_idx + time_chunks))
            t0 = time.time()
            ds = ds.load()
            logging.debug(
                f"Subdataset of {ds.time.size} dates loaded in memory in {(time.time()-t0):0.2f}s."
            )
            t0 = time.time()
            zs = _zonal_stats_numpy(ds, positions, reducers).load()
            # Release the batch before the next one is loaded.
            del ds
            logging.debug(f"Zonal stats computed in {(time.time()-t0):0.2f}s.")
            yield zs

    t_start = time.time()

    # The output always describes the caller's geometries, even when the
    # statistics are computed on a buffered version of them.
    original_geoms = geoms
    if isinstance(buffer_meters, float | int):
        input_crs = geoms.crs
        # Buffer in an equal-area projection so the distance is in meters.
        buffered = geoms.to_crs({"proj": "cea"})
        buffered.geometry = buffered.buffer(buffer_meters)
        # to_crs returns a new object: keep it, otherwise the geometries
        # silently stay in the temporary projection.
        geoms = buffered.to_crs(input_crs)

    # Clip after buffering so the pixels covered by the buffer are kept.
    dataset = dataset.rio.clip_box(*geoms.to_crs(dataset.rio.crs).total_bounds)

    if method == "numpy":
        feats, yx_pos = _rasterize(geoms.copy(), dataset, all_touched=all_touched)
        # Feature ids are 1-based (0 is the background). Only ids present in
        # the raster have pixels; iterating over them (rather than over every
        # geometry) also avoids indexing past the end of yx_pos when the last
        # geometries were not rasterised.
        found = np.unique(feats)
        found = found[found > 0]
        if found.size == 0:
            raise ValueError("None of the geometries overlap the dataset.")
        positions = [np.asarray(yx_pos[feature_id]) for feature_id in found]
        del yx_pos

        if "time" in dataset.dims and smart_load:
            time_chunks = _memory_time_chunks(dataset, memory)
            zs = xr.concat(
                list(_loop_time_chunks(dataset, positions, time_chunks)),
                dim="time",
            )
        else:
            zs = _zonal_stats_numpy(dataset, positions, reducers, **kwargs)
        zs = zs.assign_coords(zonal_statistics=reducers)
        zs = zs.rio.write_crs("EPSG:4326")

        # Row position, in the GeoDataFrame, of every geometry found.
        rows = found - 1
        feature_values = zs.feature.values
        index = xr.DataArray(
            geoms.index[rows], dims=["feature"], coords={"feature": feature_values}
        )

        # create the WKT geom (from the un-buffered geometries)
        reported = original_geoms
        if reported.crs.to_epsg() != 4326:
            reported = reported.to_crs("EPSG:4326")
        geometry = xr.DataArray(
            reported.geometry.iloc[list(rows)].to_wkt(rounding_precision=-1).values,
            dims=["feature"],
            coords={"feature": feature_values},
        )
        new_coords_kwargs = {"index": index, "geometry": geometry}

        # add the label if a column is specified
        if label:
            new_coords_kwargs["label"] = xr.DataArray(
                list(reported[label].iloc[rows]),
                dims=["feature"],
                coords={"feature": feature_values},
            )
        zs = zs.assign_coords(**new_coords_kwargs)
        zs = zs.set_index(feature=list(new_coords_kwargs))

    elif method == "xvec":
        # Imported for its side effect: it registers the ``xvec`` accessor.
        import xvec  # noqa: F401

        zs = dataset.xvec.zonal_stats(
            geoms.to_crs(dataset.rio.crs).geometry,
            y_coords="y",
            x_coords="x",
            stats=reducers,
            method="rasterize",
            all_touched=all_touched,
            **kwargs,
        )

    logging.info(f"Zonal stats method {method} tooks {time.time()-t_start}s.")
    return zs