from __future__ import annotations from bisect import bisect_right from dataclasses import dataclass from datetime import datetime, timedelta, timezone from math import cos, pi, sin from os import environ from typing import Iterator from gibil.classes.env_loader import EnvLoader @dataclass(frozen=True) class UsageSequenceScaleConfig: name: str hours: int step_seconds: int @dataclass(frozen=True) class UsageFeatureToken: name: str value: float @dataclass(frozen=True) class UsageSequenceDatasetConfig: lookback_days: int = 30 future_hours: int = 24 future_step_minutes: int = 15 stride_minutes: int = 15 local_timezone: str = "Europe/Stockholm" past_scales: tuple[UsageSequenceScaleConfig, ...] = ( UsageSequenceScaleConfig(name="recent", hours=2, step_seconds=10), UsageSequenceScaleConfig(name="medium", hours=6, step_seconds=30), UsageSequenceScaleConfig(name="daily", hours=24, step_seconds=120), ) @classmethod def from_env(cls) -> "UsageSequenceDatasetConfig": EnvLoader().load() return cls( lookback_days=int(environ.get("ASTRAPE_USAGE_SEQUENCE_LOOKBACK_DAYS", "30")), future_hours=int(environ.get("ASTRAPE_USAGE_SEQUENCE_FUTURE_HOURS", "24")), future_step_minutes=int( environ.get("ASTRAPE_USAGE_SEQUENCE_FUTURE_STEP_MINUTES", "15") ), stride_minutes=int(environ.get("ASTRAPE_USAGE_SEQUENCE_STRIDE_MINUTES", "15")), local_timezone=environ.get( "ASTRAPE_LOCAL_TIMEZONE", environ.get("TZ", "Europe/Stockholm"), ), ) @dataclass(frozen=True) class UsageSequenceExample: issued_at: datetime past_by_scale: dict[str, list[list[float]]] past_tokens_by_scale: dict[str, list[list[UsageFeatureToken]]] future_features: list[list[float]] future_tokens: list[list[UsageFeatureToken]] targets: list[float] class UsageSequenceDatasetBuilder: """Builds load forecasting windows from Sigen history.""" past_feature_names = [ "load_power_w", "solar_power_w", "grid_import_w", "grid_export_w", "battery_power_w", "battery_soc_pct", "hour_sin", "hour_cos", "dow_sin", "dow_cos", ] future_feature_names = [ "hour_sin", "hour_cos", "dow_sin", "dow_cos", "temperature_c", "shortwave_radiation_w_m2", "cloud_cover_pct", ] def __init__(self, config: UsageSequenceDatasetConfig) -> None: self.config = config @classmethod def from_env(cls) -> "UsageSequenceDatasetBuilder": return cls(UsageSequenceDatasetConfig.from_env()) def build(self, limit: int | None = None) -> list[UsageSequenceExample]: samples_by_scale = { scale.name: self._load_samples(step_seconds=scale.step_seconds) for scale in self.config.past_scales } target_samples = self._load_samples( step_seconds=self.config.future_step_minutes * 60 ) weather_by_target = self._load_weather_forecasts() if not target_samples or any(not samples for samples in samples_by_scale.values()): return [] by_scale = { name: {sample["bucket"]: sample for sample in samples} for name, samples in samples_by_scale.items() } target_by_time = { sample["bucket"]: sample for sample in target_samples } first_available = max(samples[0]["bucket"] for samples in samples_by_scale.values()) last_available = min( [samples[-1]["bucket"] for samples in samples_by_scale.values()] + [target_samples[-1]["bucket"]] ) start_at = first_available + timedelta(hours=self.max_past_hours) end_at = last_available - timedelta(hours=self.config.future_hours) issued_at = self._ceil_time(start_at, self.config.stride_minutes) examples: list[UsageSequenceExample] = [] while issued_at <= end_at: example = self._build_example( issued_at, by_scale, target_by_time, weather_by_target, ) if example is not None: examples.append(example) if limit is not None and len(examples) >= limit: break issued_at += timedelta(minutes=self.config.stride_minutes) return examples def iter_examples(self) -> Iterator[UsageSequenceExample]: for example in self.build(): yield example def _build_example( self, issued_at: datetime, by_scale: dict[str, dict[datetime, dict[str, object]]], target_by_time: dict[datetime, dict[str, object]], weather_by_target: dict[datetime, list[dict[str, object]]], ) -> UsageSequenceExample | None: future_times = [ issued_at + timedelta(minutes=self.config.future_step_minutes * offset) for offset in range(1, self.future_steps + 1) ] past_by_scale: dict[str, list[list[float]]] = {} past_tokens_by_scale: dict[str, list[list[UsageFeatureToken]]] = {} for scale in self.config.past_scales: past_times = [ issued_at - timedelta(seconds=scale.step_seconds * offset) for offset in range(self.past_steps(scale), 0, -1) ] past_rows = [ by_scale[scale.name].get(target_at) for target_at in past_times ] if any(row is None or row["load_power_w"] is None for row in past_rows): return None past_by_scale[scale.name] = [ self._past_features(row) for row in past_rows if row is not None ] past_tokens_by_scale[scale.name] = [ self._past_tokens(row) for row in past_rows if row is not None ] future_rows = [target_by_time.get(target_at) for target_at in future_times] if any(row is None or row["load_power_w"] is None for row in future_rows): return None return UsageSequenceExample( issued_at=issued_at, past_by_scale=past_by_scale, past_tokens_by_scale=past_tokens_by_scale, future_features=[ self._future_features(target_at, issued_at, weather_by_target) for target_at in future_times ], future_tokens=[ self._future_tokens(target_at=target_at, issued_at=issued_at) for target_at in future_times ], targets=[ float(row["load_power_w"]) for row in future_rows if row is not None ], ) @property def max_past_hours(self) -> int: return max(scale.hours for scale in self.config.past_scales) def past_steps(self, scale: UsageSequenceScaleConfig) -> int: return scale.hours * 60 * 60 // scale.step_seconds @property def future_steps(self) -> int: return self.config.future_hours * 60 // self.config.future_step_minutes def _past_features(self, row: dict[str, object]) -> list[float]: time_features = self._time_features(row["bucket"]) return [ self._number(row["load_power_w"]), self._number(row["solar_power_w"]), self._number(row["grid_import_w"]), self._number(row["grid_export_w"]), self._number(row["battery_power_w"]), self._number(row["battery_soc_pct"]), *time_features, ] def _past_tokens(self, row: dict[str, object]) -> list[UsageFeatureToken]: return [] def _time_features(self, value: object) -> list[float]: timestamp = value if not isinstance(timestamp, datetime): raise TypeError("timestamp must be a datetime") local = timestamp.astimezone(timezone.utc) minutes = local.hour * 60 + local.minute minute_angle = 2 * pi * minutes / 1440 dow_angle = 2 * pi * (local.isoweekday() - 1) / 7 return [ sin(minute_angle), cos(minute_angle), sin(dow_angle), cos(dow_angle), ] def _future_features( self, target_at: datetime, issued_at: datetime, weather_by_target: dict[datetime, list[dict[str, object]]], ) -> list[float]: weather = self._weather_for_target( target_at=target_at, issued_at=issued_at, weather_by_target=weather_by_target, ) return [ *self._time_features(target_at), self._number(weather.get("temperature_c")), self._number(weather.get("shortwave_radiation_w_m2")), self._number(weather.get("cloud_cover_pct")), ] def _future_tokens( self, target_at: datetime, issued_at: datetime, ) -> list[UsageFeatureToken]: return [] def _weather_for_target( self, target_at: datetime, issued_at: datetime, weather_by_target: dict[datetime, list[dict[str, object]]], ) -> dict[str, object]: forecast_target_at = self._floor_time(target_at, step_minutes=60) rows = weather_by_target.get(forecast_target_at, []) if not rows: return {} issued_values = [row["issued_at"] for row in rows] index = bisect_right(issued_values, issued_at) - 1 if index < 0: return {} return rows[index] def _load_samples(self, step_seconds: int) -> list[dict[str, object]]: EnvLoader().load() database_url = environ.get("ASTRAPE_DATABASE_URL") if not database_url: raise RuntimeError("ASTRAPE_DATABASE_URL is required") start_at = datetime.now(timezone.utc) - timedelta(days=self.config.lookback_days) bucket = self._bucket_interval(step_seconds) try: import psycopg except ImportError as error: raise RuntimeError( "Install dependencies with `python3 -m pip install -r requirements.txt`" ) from error with psycopg.connect(database_url) as connection: with connection.cursor() as cursor: cursor.execute( f""" SELECT time_bucket('{bucket}', observed_at) AS bucket, avg(load_power_w) AS load_power_w, avg(solar_power_w) AS solar_power_w, avg(grid_import_w) AS grid_import_w, avg(grid_export_w) AS grid_export_w, avg(battery_power_w) AS battery_power_w, avg(battery_soc_pct) AS battery_soc_pct FROM sigen_plant_snapshots WHERE observed_at >= %s AND observed_at <= now() GROUP BY bucket ORDER BY bucket """, (start_at,), ) rows = cursor.fetchall() return [ { "bucket": row[0], "load_power_w": row[1], "solar_power_w": row[2], "grid_import_w": row[3], "grid_export_w": row[4], "battery_power_w": row[5], "battery_soc_pct": row[6], } for row in rows ] def _load_weather_forecasts(self) -> dict[datetime, list[dict[str, object]]]: EnvLoader().load() database_url = environ.get("ASTRAPE_DATABASE_URL") if not database_url: raise RuntimeError("ASTRAPE_DATABASE_URL is required") start_at = datetime.now(timezone.utc) - timedelta(days=self.config.lookback_days) end_at = datetime.now(timezone.utc) + timedelta(hours=self.config.future_hours) try: import psycopg except ImportError as error: raise RuntimeError( "Install dependencies with `python3 -m pip install -r requirements.txt`" ) from error with psycopg.connect(database_url) as connection: with connection.cursor() as cursor: cursor.execute( """ SELECT issued_at, target_at, temperature_c, shortwave_radiation_w_m2, cloud_cover_pct FROM weather_forecast_points WHERE target_at >= %s AND target_at <= %s ORDER BY target_at, issued_at """, (start_at, end_at), ) rows = cursor.fetchall() by_target: dict[datetime, list[dict[str, object]]] = {} for row in rows: by_target.setdefault(row[1], []).append( { "issued_at": row[0], "target_at": row[1], "temperature_c": row[2], "shortwave_radiation_w_m2": row[3], "cloud_cover_pct": row[4], } ) return by_target def _bucket_interval(self, step_seconds: int) -> str: if step_seconds % 60 == 0: return f"{step_seconds // 60} minutes" return f"{step_seconds} seconds" def _ceil_time(self, value: datetime, step_minutes: int) -> datetime: step_seconds = step_minutes * 60 timestamp = value.timestamp() remainder = timestamp % step_seconds if remainder: timestamp += step_seconds - remainder return datetime.fromtimestamp(timestamp, timezone.utc) def _floor_time(self, value: datetime, step_minutes: int) -> datetime: step_seconds = step_minutes * 60 timestamp = value.timestamp() timestamp -= timestamp % step_seconds return datetime.fromtimestamp(timestamp, timezone.utc) def _number(self, value: object) -> float: if value is None: return 0.0 return float(value)