from collections import defaultdict
import logging
import pandas as pd
import geopandas as gpd
import numpy as np
import xarray as xr
from rasterio.enums import Resampling
from shapely.geometry import box
from .geometry_manager import GeometryManager
from ._zonal import zonal_stats, zonal_stats_numpy
from .harmonizer import Harmonizer
from .asset_mapper import AssetMapper
import rioxarray
from functools import wraps
import json
from typing import Callable
__all__ = ["GeometryManager", "rioxarray", "zonal_stats", "zonal_stats_numpy"]
_auto_mask_order = ["cloudmask", "ag_cloud_mask", "native"]
def _datacube_masks(method: Callable) -> Callable:
"""
Decorator to handle automatic mask selection and application.
This decorator provides a flexible way to apply masks to a datacube,
with an 'auto' mode that tries multiple mask options.
Parameters
----------
method : Callable
The method to be wrapped with mask application logic.
Returns
-------
Callable
A wrapped method with enhanced mask handling capabilities.
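
    Examples
    --------
    A minimal sketch of the intended behaviour; ``eds`` stands for any object
    whose decorated method accepts ``mask_with`` (hypothetical here)::

        # "auto" expands to ["cloudmask", "ag_cloud_mask", "native"]; each
        # mask is tried in turn and the first one that succeeds is used
        cube = eds.datacube("sentinel-2-l2a", mask_with="auto")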
"""
@wraps(method)
def wrapper(self, *args, **kwargs):
# Handle mask selection
mask_with = kwargs.get("mask_with", None)
# If 'auto' is specified, use the predefined auto mask order
        if mask_with == "auto":
mask_with = _auto_mask_order
# If mask_with is a list, try each mask sequentially
if isinstance(mask_with, list):
kwargs.pop("mask_with", None)
            for i, mask in enumerate(mask_with):
                try:
                    return method(self, *args, mask_with=mask, **kwargs)
                except Exception:
                    # If this was the last mask, re-raise the exception
                    if i == len(mask_with) - 1:
                        raise
# If no special mask handling is needed, call the method directly
return method(self, *args, **kwargs)
return wrapper
def _datacubes(method: Callable) -> Callable:
"""
Decorator to handle multiple collections and create meta-datacubes.
This decorator provides logic for processing multiple collections,
allowing creation of meta-datacubes when multiple collections are provided.
Parameters
----------
method : Callable
The method to be wrapped with multi-collection handling logic.
Returns
-------
Callable
A wrapped method with enhanced multi-collection processing capabilities.
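
    Examples
    --------
    A minimal sketch; ``eds`` is a hypothetical object whose decorated method
    accepts ``collections``::

        # each collection is loaded on the first datacube's geobox and the
        # results are merged into a single meta-datacube
        cube = eds.datacube(collections=["sentinel-2-l2a", "landsat-c2l2-sr"])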
"""
@wraps(method)
@_datacube_masks
def wrapper(self, *args, **kwargs):
# Determine collections from args or kwargs
collections = kwargs.get("collections", args[0] if len(args) > 0 else None)
# If multiple collections are provided, process them separately
if isinstance(collections, list) and len(collections) > 1:
# Remove collections from kwargs or args
if "collections" in kwargs:
kwargs.pop("collections")
else:
args = args[1:]
# Process each collection
datacubes = []
for idx, collection in enumerate(collections):
# Create datacube for each collection
datacube = method(self, collections=collection, *args, **kwargs)
# Use the first datacube's geobox for subsequent datacubes
if idx == 0:
kwargs["geobox"] = datacube.odc.geobox
datacubes.append(datacube)
# Combine datacubes into a meta-datacube
return metacube(*datacubes)
# If only one collection, process normally
return method(self, *args, **kwargs)
return wrapper
def _match_xy_dims(src, dst, resampling=Resampling.nearest):
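    """Reproject ``src`` onto ``dst``'s grid when their x/y sizes differ."""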
if (src.sizes["x"], src.sizes["y"]) != (dst.sizes["x"], dst.sizes["y"]):
src = src.rio.reproject_match(dst, resampling=resampling)
return src
def _bbox_to_intersects(bbox):
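    """Convert a bbox (CSV string or [xmin, ymin, xmax, ymax] sequence) to an EPSG:4326 GeoDataFrame."""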
if isinstance(bbox, str):
bbox = [float(i) for i in bbox.split(",")]
return gpd.GeoDataFrame(geometry=[box(*bbox)], crs="EPSG:4326")
def _apply_nodata(ds, nodata_assets: dict):
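    """Write each asset's nodata value and mask the matching pixels to NaN."""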
for asset, nodata in nodata_assets.items():
ds[asset].rio.set_nodata(nodata)
ds[asset] = ds[asset].where(ds[asset] != nodata)
return ds
def _autofix_unfrozen_coords_dtype(ds):
attrs = {c: ds.coords[c].data.tolist() for c in ds.coords if c not in ds.sizes}
    # force unsupported coordinate dtypes to plain str
for attr in attrs:
if not isinstance(attrs[attr], (str, int, float, np.ndarray, list, tuple)):
ds.coords[attr] = str(attrs[attr])
ds.coords[attr] = ds.coords[attr].astype(str)
return ds
def _cube_odc(
items_collection,
assets=None,
times=None,
dtype="float32",
properties=False,
**kwargs,
):
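    """Load ``items_collection`` as an xarray Dataset using the odc-stac engine."""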
from odc import stac
if "epsg" in kwargs:
kwargs["crs"] = f"EPSG:{kwargs['epsg']}"
kwargs.pop("epsg")
if "resampling" in kwargs:
if isinstance(kwargs["resampling"], int):
kwargs["resampling"] = Resampling(kwargs["resampling"]).name
kwargs["chunks"] = kwargs.get("chunks", dict(x="auto", y="auto", time=1))
if "geobox" in kwargs.keys() and "geopolygon" in kwargs.keys():
kwargs.pop("geopolygon")
ds = stac.load(
items_collection,
bands=assets,
preserve_original_order=True,
dtype=dtype,
groupby=None,
**kwargs,
)
if properties:
metadata = defaultdict(list)
for i in items_collection:
            # if properties is a single property key
if isinstance(properties, str):
metadata[properties].append(i.properties[properties])
else:
for k, v in i.properties.items():
if isinstance(properties, list):
if k not in properties:
continue
if isinstance(v, list):
v = str(v)
metadata[k].append(v)
# to avoid mismatch if some properties are not available on all items
df = pd.DataFrame.from_dict(metadata, orient="index").T
# convert to xarray needs
metadata = {k: ("time", v.tolist()) for k, v in df.items()}
# assign metadata as coords
ds = ds.assign_coords(**metadata)
if "latitude" in ds.coords and "longitude" in ds.coords:
ds = ds.rename({"latitude": "y", "longitude": "x"})
ds = ds.chunk(kwargs["chunks"])
return ds
def _cube_stackstac(items_collection, assets=None, times=None, **kwargs):
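    """Load ``items_collection`` as an xarray Dataset using the stackstac engine."""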
from stackstac import stack
if "epsg" in kwargs:
kwargs["epsg"] = int(kwargs["epsg"])
if kwargs.get("geobox") is not None:
kwargs["resolution"] = kwargs["geobox"].resolution.x
kwargs["epsg"] = kwargs["geobox"].crs.to_epsg()
if "geobox" in kwargs.keys():
kwargs.pop("geobox")
ds = stack(
items_collection,
assets=assets,
rescale=False,
xy_coords="center",
**kwargs,
)
ds = ds.to_dataset(dim="band")
if "band" in ds.sizes:
ds = ds.drop_dims("band")
    for data_var in ds.data_vars:
        ds[data_var] = ds[data_var].rio.write_crs(ds.rio.crs)
if times:
ds["time"] = times
return ds
def datacube(
items_collection=None,
bbox=None,
intersects=None,
    assets: list | dict | None = None,
engine="odc",
rescale=True,
groupby_date="mean",
common_band_names=True,
cross_calibration_items: list | None = None,
    properties: bool | str | list = False,
**kwargs,
):
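    """
    Build an xarray datacube from a collection of STAC items.

    Parameters
    ----------
    items_collection : list
        STAC items to load.
    bbox : list or str, optional
        Bounding box (xmin, ymin, xmax, ymax, EPSG:4326) used to clip the
        cube when ``intersects`` is not provided.
    intersects : geometry-like, optional
        Any geometry understood by GeometryManager; used to clip the cube.
    assets : list or dict, optional
        Assets to load. If a dict, its keys are loaded and the data variables
        are renamed to the dict values.
    engine : str, default "odc"
        Stacking engine, either "odc" or "stackstac".
    rescale : bool, default True
        Apply scale/offset from the items' ``raster:bands`` metadata.
    groupby_date : str, default "mean"
        Reduction applied when several items share the same date (e.g. "mean").
        A falsy value skips the grouping.
    common_band_names : bool, default True
        Map collection-specific asset names to common band names.
    cross_calibration_items : list, optional
        Items passed to the Harmonizer for cross-calibration.
    properties : bool or str or list, default False
        Item properties to attach as extra coordinates along ``time``.

    Returns
    -------
    xarray.Dataset
        The resulting datacube.
    """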
logging.info(f"Building datacube with {len(items_collection)} items")
    times = [
        np.datetime64(d.datetime.strftime("%Y-%m-%d %H:%M:%S.%f")).astype(
            "datetime64[ns]"
        )
        for d in items_collection
        if getattr(d, "datetime", None) is not None
    ]
if len(times) == 0:
times = None
engines = {"odc": _cube_odc, "stackstac": _cube_stackstac}
if engine not in engines:
raise NotImplementedError(
f"Engine '{engine}' not supported. Only {' and '.join(list(engines.keys()))} are currently supported."
)
if common_band_names and not isinstance(assets, dict):
aM = AssetMapper()
assets = aM.map_collection_assets(items_collection[0].collection_id, assets)
if isinstance(assets, dict):
assets_keys = list(assets.keys())
if engine == "odc" and intersects is not None:
kwargs["geopolygon"] = GeometryManager(intersects).to_geopandas()
if engine == "stackstac" and intersects is not None:
kwargs["bounds_latlon"] = list(
GeometryManager(intersects).to_geopandas().to_crs(epsg=4326).total_bounds
)
# create datacube using the defined engine (default is odc stac)
ds = engines[engine](
items_collection,
assets=assets_keys if isinstance(assets, dict) else assets,
times=times,
properties=properties,
**kwargs,
)
# check nodata per asset (data_vars)
# TODO : replace the original no_data with a defined value
# (like min float) because of rescale
    nodatas = {}
    for ds_asset in ds.data_vars:
        asset, band_idx = _parse_asset_band(ds_asset)
        for item in items_collection:
            if asset not in item.assets:
                continue
            raster_bands = item.assets[asset].extra_fields.get("raster:bands", [])
            if len(raster_bands) < band_idx:
                continue
            nodata = raster_bands[band_idx - 1].get("nodata")
            if nodata is not None:
                nodatas[ds_asset] = nodata
                break
    # drop NaT (missing) dates
    ds = ds.isel(dict(time=np.where(~np.isnan(ds.time))[0]))
    if groupby_date:
        # aggregate tiles that share exactly the same datetime
        if ds.time.size != np.unique(ds.time).size:
            ds = ds.groupby("time")
            ds = getattr(ds, groupby_date)()
if bbox is not None and intersects is None:
intersects = _bbox_to_intersects(bbox)
if intersects is not None:
intersects = GeometryManager(intersects).to_geopandas()
if isinstance(intersects, gpd.GeoDataFrame):
        # pre-clipping with the bbox is no longer needed thanks to
        # geobox/geopolygon support in odc-stac
# ds = ds.rio.clip_box(*intersects.to_crs(ds.rio.crs).total_bounds)
ds = ds.rio.clip(intersects.to_crs(ds.rio.crs).geometry)
# apply nodata
ds = _apply_nodata(ds, nodatas)
if rescale:
ds = rescale_assets_with_items(items_collection, ds, assets=assets)
if engine == "stackstac":
ds = _autofix_unfrozen_coords_dtype(ds)
if cross_calibration_items is not None and len(cross_calibration_items) > 0:
ds = Harmonizer.harmonize(items_collection, ds, cross_calibration_items, assets)
if groupby_date:
if ds.time.size != np.unique(ds.time.dt.strftime("%Y%m%d")).size:
ds = ds.groupby("time.date")
ds = getattr(ds, groupby_date)().rename(dict(date="time"))
ds["time"] = ds.time.astype("<M8[ns]")
if isinstance(assets, dict):
ds = ds.rename(assets)
for coord in ds.coords:
if ds.coords[coord].values.shape == ():
continue
if isinstance(ds.coords[coord].values[0], (list, dict)):
ds.coords[coord].values = [
json.dumps(ds.coords[coord].values[idx])
for idx in range(ds.coords[coord].size)
]
return ds
def rescale_assets_with_items(
items_collection: list,
ds: xr.Dataset,
assets: None | list[str] = None,
boa_offset_applied_control: bool = True,
boa_offset_applied_force_by_date: bool = True,
) -> xr.Dataset:
"""
Rescale assets in a dataset based on collection items' metadata.
Parameters
----------
items_collection : List
Collection of items containing asset scaling information.
ds : xarray.Dataset
Input dataset to be rescaled.
assets : List[str], optional
List of assets to rescale. If None, uses all dataset variables.
boa_offset_applied_control : bool, default True
Apply Bottom of Atmosphere (BOA) offset control for Sentinel-2 L2A data.
boa_offset_applied_force_by_date : bool, default True
        Assume the BOA offset is already applied for dates on or after 2022-02-28.
Returns
-------
xarray.Dataset
Rescaled dataset with applied offsets and scales.
Raises
------
ValueError
If there's a mismatch between items and datacube time or dates.
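
    Examples
    --------
    A minimal sketch; ``items`` and ``ds`` are assumed to come from the same
    STAC search, so their dates line up::

        ds = rescale_assets_with_items(items, ds, assets=["red", "nir"])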
"""
logging.info("Rescaling dataset")
# Deduplicate items by datetime
unique_items = {}
for item in items_collection:
unique_items.setdefault(item.datetime, item)
items_collection = list(unique_items.values())
# Validate items match dataset time
if len(items_collection) != ds.time.size:
raise ValueError(
"Mismatch between items and datacube time. Set rescale to False."
)
# Prepare assets list
assets = assets or list(ds.data_vars.keys())
scales: dict[str, dict[float, dict[float, list]]] = {}
# Process scaling for each time step
for idx, time in enumerate(ds.time.values):
item = items_collection[idx]
# Date validation
if pd.Timestamp(time).strftime("%Y%m%d") != item.datetime.strftime("%Y%m%d"):
raise ValueError(
"Mismatch between items and datacube dates. Set rescale to False."
)
# BOA offset handling for Sentinel-2 L2A
boa_offset_applied = item.properties.get(
"earthsearch:boa_offset_applied", False
)
if boa_offset_applied_control and item.collection_id == "sentinel-2-l2a":
if boa_offset_applied_force_by_date:
boa_offset_applied = pd.Timestamp(time) >= pd.Timestamp("2022-02-28")
# Process each asset
for ds_asset in assets:
# Handle multi-band assets
asset, band_idx = _parse_asset_band(ds_asset)
if asset not in item.assets:
continue
raster_bands = item.assets[asset].extra_fields.get("raster:bands", [])
if not raster_bands or len(raster_bands) < band_idx:
continue
rasterbands = raster_bands[band_idx - 1]
scale = rasterbands.get("scale", 1)
offset = rasterbands.get("offset", 0)
# Special handling for Sentinel-2 BOA offset
if (
item.collection_id == "sentinel-2-l2a"
and boa_offset_applied_control
and ds_asset
in [
"blue",
"red",
"green",
"nir",
"nir08",
"nir09",
"swir16",
"swir22",
"rededge1",
"rededge2",
"rededge3",
]
and boa_offset_applied
):
offset = 0
# Track scaling parameters
scales.setdefault(ds_asset, {}).setdefault(scale, {}).setdefault(
offset, []
).append(time)
# Apply rescaling
if scales:
scaled_assets = []
for asset, scale_data in scales.items():
asset_scaled = []
for scale, offset_data in scale_data.items():
for offset, times in offset_data.items():
                    mask = np.isin(ds.time, times)
asset_scaled.append(ds[[asset]].isel(time=mask) * scale + offset)
scaled_assets.append(xr.concat(asset_scaled, dim="time"))
# Merge scaled assets
ds_scaled = xr.merge(scaled_assets).sortby("time")
# Preserve unscaled variables
missing_vars = [var for var in ds.data_vars if var not in scales]
if missing_vars:
ds_scaled = xr.merge([ds_scaled, ds[missing_vars]])
ds_scaled.attrs = ds.attrs
ds = ds_scaled
logging.info("Rescaling complete")
return ds.sortby("time")
def _parse_asset_band(ds_asset: str) -> tuple[str, int]:
"""
Parse asset and band index from asset name.
Parameters
----------
ds_asset : str
Asset name, potentially with band index.
Returns
-------
tuple[str, int]
Tuple of (asset name, band index)
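
    Examples
    --------
    >>> _parse_asset_band("rededge.2")
    ('rededge', 2)
    >>> _parse_asset_band("red")
    ('red', 1)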
"""
parts = ds_asset.split(".")
if len(parts) == 2 and parts[1][-1].isdigit():
return parts[0], int(parts[1][-1])
return ds_asset, 1
def _propagade_rio(src, ds):
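    """Propagate CRS-related coordinates ("epsg", "spatial_ref") from ``src`` to ``ds``."""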
coords = ["epsg", "spatial_ref"]
for coord in coords:
if coord in src:
            ds = ds.assign_coords({coord: src[coord]})
return ds
def _drop_unfrozen_coords(ds):
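    """Drop every non-dimension ("unfrozen") coordinate from the dataset."""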
unfrozen_coords = [i for i in list(ds.coords) if i not in ds.sizes]
    ds = ds.drop_vars(unfrozen_coords)
return ds
def _common_data_vars(*cubes):
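    """Return the union of data variable names across all cubes."""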
    data_vars = list({k for cube in cubes for k in cube.data_vars})
return data_vars
def _groupby(ds, by="time.date", how="mean"):
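    """Group ``ds`` by ``by`` and reduce with ``how`` when labels are duplicated."""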
    condition = ds[by].size != np.unique(ds[by]).size
if condition:
ds = ds.groupby(by)
ds = getattr(ds, how)()
if by == "time.date":
ds = ds.rename(dict(date="time"))
ds["time"] = ds.time.astype("<M8[ns]")
return ds
def _have_same_xy(*cubes):
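    """Return True when all cubes share identical x and y sizes."""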
x_size = list(set(cube.sizes["x"] for cube in cubes))
y_size = list(set(cube.sizes["y"] for cube in cubes))
    return len(x_size) == 1 and len(y_size) == 1