import logging
from collections.abc import Callable, Iterable, MutableMapping, Sequence
from dataclasses import dataclass
from time import perf_counter
from typing import Any, Generic
import pandas as pd
from rics.collections.dicts import flatten_dict
from time_split._compat import fmt_sec
from ..settings import log_split_progress as settings
from ..types import DatetimeSplitBounds, FormatMetrics, GetMetrics, LoggerArg, MetricsType, SplitProgressExtras
from ._to_string import _PrettyTimestamp
[docs]
def log_split_progress(
    splits: Sequence[DatetimeSplitBounds],
    *,
    logger: LoggerArg = "time_split",
    start_level: int = logging.INFO,
    end_level: int = logging.INFO,
    extra: dict[str, Any] | None = None,
    get_metrics: GetMetrics[MetricsType] | None = None,
) -> Iterable[DatetimeSplitBounds]:
    """Log iteration progress.

    Args:
        splits: Splits to iterate over.
        logger: Logger or logger name to use.
        start_level: Log level to use for the :attr:`fold-begin message <.settings.log_split_progress.START_MESSAGE>`.
        end_level: Log level to use for the :attr:`fold-end message <.settings.log_split_progress.END_MESSAGE>`.
        extra: User-defined `extra`-arguments to use when logging, merged with progress-related extras (see
            :class:`~time_split.types.SplitProgressExtras`). A shallow copy is taken when this function is called;
            nested mutable values are still shared with the caller.
        get_metrics: A callable ``(training_date) -> fold_metrics | str`` (see :attr:`~.DatetimeSplit.training_date`).
            If given, metrics are added to the :attr:`fold-end message <.settings.log_split_progress.END_MESSAGE>`. The
            message is formatted using the :func:`default formatter <.support.default_metrics_formatter>` unless
            :attr:`~.settings.log_split_progress.FORMAT_METRICS` is set. If this callback returns a ``str`` argument,
            the :func:`default formatter <.support.default_metrics_formatter>` will assume that the metrics are
            pre-formatted, simply appending the formatted metrics to the
            :attr:`fold-end message <.settings.log_split_progress.END_MESSAGE>` as-is.

    Returns:
        An iterable over `splits`.

    Examples:
        Configuring the `logger` name and
        :attr:`fold-begin message <.settings.log_split_progress.START_MESSAGE>` log level.

        >>> from time_split import split, log_split_progress
        >>> schedule = ["2023-08-16", "2023-08-17 12", "2023-08-19"]
        >>> tracked_splits = log_split_progress(
        ...     split(schedule),
        ...     logger="progress",
        ...     start_level=logging.DEBUG,
        ... )
        >>> list(tracked_splits)  # doctest: +SKIP
        [progress:DEBUG] Begin fold 1/2: '2023-08-09' <= [schedule: '2023-08-16' (Wednesday)] < '2023-08-17 12:00:00'.
        [progress:INFO] Finished fold 1/2 [schedule: '2023-08-16' (Wednesday)] after 5m 18s.
        [progress:DEBUG] Begin fold 2/2: '2023-08-10 12:00:00' <= [schedule: '2023-08-17 12:00:00' (Thursday)] < '2023-08-19'.
        [progress:INFO] Finished fold 2/2 [schedule: '2023-08-17 12:00:00' (Thursday)] after 4m 3s.

        Using the `get_metrics` callback argument.

        >>> metrics = {
        ...     "2023-08-16 00:00:00": {"rmse": {"train": 0.11, "test": 0.5}},
        ...     "2023-08-17 12:00:00": {"rmse": {"test": 0.5, "future": 20.19}},
        ... }
        >>> tracked_splits = log_split_progress(
        ...     split(schedule),
        ...     get_metrics=lambda key: metrics[str(key)],
        ... )
        >>> list(tracked_splits)  # doctest: +SKIP
        [time_split:INFO] Begin fold 1/2: '2023-08-09' <= [schedule: '2023-08-16' (Wednesday)] < '2023-08-17 12:00:00'.
        [time_split:INFO] Finished fold 1/2 [schedule: '2023-08-16' (Wednesday)] after 5m 18s. Fold metrics:
        rmse.train 0.11
        rmse.test 0.5
        [time_split:INFO] Begin fold 2/2: '2023-08-10 12:00:00' <= [schedule: '2023-08-17 12:00:00' (Thursday)] < '2023-08-19'.
        [time_split:INFO] Finished fold 2/2 [schedule: '2023-08-17 12:00:00' (Thursday)] after 4m 3s. Fold metrics:
        rmse.test 0.5
        rmse.future 20.19

        Formatting was done using the :func:`default formatter <.support.default_metrics_formatter>`, since the
        :attr:`~.settings.log_split_progress.FORMAT_METRICS` setting is ``None``.
    """
    logger = logging.getLogger(logger) if isinstance(logger, str) else logger

    if isinstance(logger, logging.LoggerAdapter) and not hasattr(logger, "merge_extra"):
        # Backport of https://github.com/python/cpython/pull/107292
        # Adapters without `merge_extra` replace per-call extras with their own; rewrap so both are kept.
        logger = _MergingLoggerAdapter(logger.logger, logger.extra)

    # Settings are read once, here, and frozen into the tracker for the whole iteration.
    track = _ProgressTracker(
        logger=logger,
        fold_format=settings.FOLD_FORMAT,
        start_level=start_level,
        start_message=settings.START_MESSAGE,
        end_level=end_level,
        end_message=settings.END_MESSAGE,
        seconds_formatter=settings.SECONDS_FORMATTER or fmt_sec,
        user_extra={} if extra is None else extra.copy(),  # Not actually immutable; deepcopy can be very expensive.
        get_metrics=get_metrics,
        format_metrics=settings.FORMAT_METRICS or default_metrics_formatter,
    )

    return track(splits)
@dataclass(frozen=True)
class _ProgressTracker(Generic[MetricsType]):
    """Generator-based helper that logs a message before and after each yielded split.

    Calling an instance with a sequence of splits returns a generator. For every split, the start message is
    logged, the split is yielded, and the end message is logged once the consumer asks for the next item.
    """

    # Logger (or adapter) that receives both messages.
    logger: logging.Logger | logging.LoggerAdapter  # type: ignore[type-arg]
    # Format string used to build the ``fold`` placeholder; receives the same keys as the messages.
    fold_format: str
    # Level and format string for the message logged before a split is yielded.
    start_level: int
    start_message: str
    # Level and format string for the message logged after the consumer is done with a split.
    end_level: int
    end_message: str
    # Converts the elapsed seconds to the ``formatted_seconds`` placeholder.
    seconds_formatter: Callable[[float], str]
    # User-supplied extras; added to the log record extras and made available as format placeholders.
    user_extra: dict[str, Any]
    # Optional callback, called with ``split.mid`` after each fold.
    get_metrics: GetMetrics[MetricsType] | None
    # Combines the formatted end message with the value returned by `get_metrics`.
    format_metrics: FormatMetrics[MetricsType]

    def __call__(self, splits: Sequence[DatetimeSplitBounds]) -> Iterable[DatetimeSplitBounds]:
        """Yield each element of `splits`, logging progress around it.

        Args:
            splits: Splits to iterate over. Must support ``len()``.

        Yields:
            The elements of `splits`, in order.

        Notes:
            The end message of a fold is emitted when the generator is resumed. If the consumer stops iterating
            after receiving a split, no end message is logged for that split.
        """
        n_splits = len(splits)

        for n, split in enumerate(splits, start=1):
            default_extras = SplitProgressExtras(
                n=n,
                n_splits=n_splits,
                start=split.start.isoformat(),
                mid=split.mid.isoformat(),
                end=split.end.isoformat(),
            )
            extra = {**self.user_extra, **default_extras}

            # Built-in keys are written last so that they take precedence over user extras with the same name,
            # just like in `extra` above. Using dict(n=..., **user_extra) would raise TypeError on a name clash.
            kwargs: dict[str, Any] = {
                **self.user_extra,
                "n": n,
                "n_splits": n_splits,
                "start": _PrettyTimestamp(split.start),
                "mid": _PrettyTimestamp(split.mid),
                "end": _PrettyTimestamp(split.end),
            }
            kwargs["fold"] = self.fold_format.format(**kwargs)

            self.logger.log(self.start_level, self.start_message.format(**kwargs), extra=extra)

            # Yield split and count user time.
            start = perf_counter()
            yield split
            seconds = round(perf_counter() - start, 6)

            kwargs.update(
                seconds=seconds,
                formatted_seconds=self.seconds_formatter(seconds),
            )
            msg = self.end_message.format(**kwargs)

            if self.get_metrics is not None:
                # Raw metrics go into the record extras; the formatter decides how they appear in the message.
                extra["metrics"] = self.get_metrics(split.mid)
                msg = self.format_metrics(msg, extra["metrics"])

            extra.update(seconds=seconds)
            self.logger.log(self.end_level, msg, extra=extra)
class _MergingLoggerAdapter(logging.LoggerAdapter[Any]):
# TODO(3.13): Use merge_extra=True init arg
def process(self, msg: Any, kwargs: MutableMapping[str, Any]) -> tuple[Any, MutableMapping[str, Any]]:
"""See https://github.com/python/cpython/pull/107292."""
kwargs["extra"] = {**self.extra, **kwargs["extra"]} if "extra" in kwargs and self.extra else self.extra
return msg, kwargs
def _convert_dict(metrics: dict[Any, Any]) -> dict[Any, Any] | pd.DataFrame | pd.Series:
original = metrics
scalar_leaves: bool = False
if all(isinstance(v, dict) for v in metrics.values()):
flat = flatten_dict(metrics)
# A single dot usually indicates nesting one level deep; can be formatted as a DataFrame.
if any(k.count(".") != 1 for k in flat) and all(pd.api.types.is_scalar(v) for v in flat.values()):
metrics = flat
elif all(pd.api.types.is_scalar(v) for v in original.values()):
scalar_leaves = True
try:
if scalar_leaves:
return pd.Series(metrics)
df = pd.DataFrame(metrics)
except Exception:
return original
if len(df) > len(df.columns):
df = df.T # Prefer wider horizontal output
return df