Source code for time_split._frontend._weight

from typing import Literal

from pandas import DatetimeIndex, Timedelta

from .._compat import fix_pandas4_warning
from .._support import handle_dask
from ..types import DatetimeIterable, DatetimeSplitCounts, DatetimeSplits


[docs] def fold_weight( splits: DatetimeSplits, *, unit: str | Literal["rows", "hours", "days"] = "hours", available: DatetimeIterable | None = None, ) -> list[DatetimeSplitCounts]: """Compute fold weights. Args: splits: List of :attr:`~time_split.types.DatetimeSplitBounds`. unit: Time unit of the returned count, or `'rows'` (requires `available` data). available: Available data. Required when ``unit='rows'``. Returns: A list of tuples ``[(n_data_units, n_future_data_units), ...]``. Raises: ValueError: if ``unit='rows'`` and ``available=None``. """ if unit == "rows": if available is None: raise ValueError(f"Must provide available data when {unit=}.") return _from_data(available, splits) else: return _from_unit(unit, splits)
def _from_unit(unit: str, splits: DatetimeSplits) -> list[DatetimeSplitCounts]: resolution = Timedelta(1, unit=fix_pandas4_warning(unit)) return [ DatetimeSplitCounts( round((fold.mid - fold.start) / resolution), round((fold.end - fold.mid) / resolution), ) for fold in splits ] def _from_data(available: DatetimeIterable, splits: DatetimeSplits) -> list[DatetimeSplitCounts]: if hasattr(available, "to_series"): time = available.to_series() # Works for Dask and Pandas. elif hasattr(available, "between"): time = available # This is what we want. Dask index has this attribute, but calling it fails. else: time = DatetimeIndex(available).to_series() # Coerce everything else. retval = [] for fold in splits: data = time.between(fold.start, fold.mid, inclusive="left").sum() future_data = time.between(fold.mid, fold.end, inclusive="left").sum() data = handle_dask(data) future_data = handle_dask(future_data) retval.append(DatetimeSplitCounts(data, future_data=future_data)) return retval