from __future__ import annotations

from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from os import environ
from typing import Iterator

from gibil.classes.models import NetPowerForecastRun, PowerForecastRun


class OracleStoreConfigurationError(RuntimeError):
    """Raised when oracle storage is misconfigured or its DB driver is absent."""

    pass


@dataclass(frozen=True)
class OracleStoreConfig:
    """Connection settings for the oracle store."""

    # psycopg-compatible PostgreSQL/TimescaleDB connection URL.
    database_url: str

    @classmethod
    def from_env(cls) -> "OracleStoreConfig":
        """Build a config from the ``ASTRAPE_DATABASE_URL`` environment variable.

        Raises:
            OracleStoreConfigurationError: if the variable is unset or empty.
        """
        database_url = environ.get("ASTRAPE_DATABASE_URL")
        if not database_url:
            raise OracleStoreConfigurationError(
                "ASTRAPE_DATABASE_URL is required for oracle storage"
            )
        return cls(database_url=database_url)


class OracleStore:
    """Persists generated oracle projection curves for later evaluation."""

    def __init__(self, config: OracleStoreConfig) -> None:
        """Store the connection config; no connection is opened here."""
        self.config = config

    @classmethod
    def from_env(cls) -> "OracleStore":
        """Construct a store configured from the process environment."""
        return cls(OracleStoreConfig.from_env())

    def initialize(self) -> None:
        """Create (or migrate) all oracle tables, hypertables, and indexes.

        Idempotent: every statement uses IF NOT EXISTS / if_not_exists, and the
        ALTER TABLE blocks backfill columns added after the table's first
        deployment, so this is safe to run on every write path.
        """
        with self._connection() as connection:
            with connection.cursor() as cursor:
                # TimescaleDB must be available before create_hypertable calls.
                cursor.execute("CREATE EXTENSION IF NOT EXISTS timescaledb")
                # Per-kind (solar/load) forecast curve points, keyed per model run.
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS oracle_power_forecast_points (
                        issued_at TIMESTAMPTZ NOT NULL,
                        target_at TIMESTAMPTZ NOT NULL,
                        kind TEXT NOT NULL,
                        source TEXT NOT NULL,
                        model_version TEXT NOT NULL,
                        horizon_minutes INTEGER NOT NULL,
                        expected_power_w DOUBLE PRECISION NOT NULL,
                        p10_power_w DOUBLE PRECISION NOT NULL,
                        p50_power_w DOUBLE PRECISION NOT NULL,
                        p90_power_w DOUBLE PRECISION NOT NULL,
                        confidence DOUBLE PRECISION NOT NULL,
                        inserted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                        PRIMARY KEY (issued_at, target_at, kind, source, model_version)
                    )
                    """
                )
                # Partition by target_at so evaluation scans stay chunk-local.
                cursor.execute(
                    """
                    SELECT create_hypertable(
                        'oracle_power_forecast_points',
                        'target_at',
                        if_not_exists => TRUE
                    )
                    """
                )
                # Net (solar - load) forecast curve points. Quantile and
                # component-quantile columns are nullable where they were
                # introduced after initial deployment (see ALTERs below).
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS oracle_net_forecast_points (
                        issued_at TIMESTAMPTZ NOT NULL,
                        target_at TIMESTAMPTZ NOT NULL,
                        source TEXT NOT NULL,
                        horizon_minutes INTEGER NOT NULL,
                        expected_net_power_w DOUBLE PRECISION NOT NULL,
                        safe_net_power_w DOUBLE PRECISION NOT NULL,
                        p10_net_power_w DOUBLE PRECISION,
                        p50_net_power_w DOUBLE PRECISION,
                        p90_net_power_w DOUBLE PRECISION,
                        solar_p50_power_w DOUBLE PRECISION NOT NULL,
                        load_p50_power_w DOUBLE PRECISION NOT NULL,
                        solar_p10_power_w DOUBLE PRECISION NOT NULL,
                        solar_p90_power_w DOUBLE PRECISION,
                        load_p10_power_w DOUBLE PRECISION,
                        load_p90_power_w DOUBLE PRECISION NOT NULL,
                        inserted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                        PRIMARY KEY (issued_at, target_at, source)
                    )
                    """
                )
                # Migrations: add columns that older deployments lack. Each is a
                # no-op where the CREATE TABLE above already included the column.
                cursor.execute(
                    """
                    ALTER TABLE oracle_net_forecast_points
                    ADD COLUMN IF NOT EXISTS p10_net_power_w DOUBLE PRECISION
                    """
                )
                cursor.execute(
                    """
                    ALTER TABLE oracle_net_forecast_points
                    ADD COLUMN IF NOT EXISTS p50_net_power_w DOUBLE PRECISION
                    """
                )
                cursor.execute(
                    """
                    ALTER TABLE oracle_net_forecast_points
                    ADD COLUMN IF NOT EXISTS p90_net_power_w DOUBLE PRECISION
                    """
                )
                cursor.execute(
                    """
                    ALTER TABLE oracle_net_forecast_points
                    ADD COLUMN IF NOT EXISTS solar_p90_power_w DOUBLE PRECISION
                    """
                )
                cursor.execute(
                    """
                    ALTER TABLE oracle_net_forecast_points
                    ADD COLUMN IF NOT EXISTS load_p10_power_w DOUBLE PRECISION
                    """
                )
                cursor.execute(
                    """
                    SELECT create_hypertable(
                        'oracle_net_forecast_points',
                        'target_at',
                        if_not_exists => TRUE
                    )
                    """
                )
                # Evaluation results joining forecast points against realized
                # telemetry; realized/error columns are NULL until evaluated
                # with at least one sample.
                cursor.execute(
                    """
                    CREATE TABLE IF NOT EXISTS oracle_forecast_evaluations (
                        issued_at TIMESTAMPTZ NOT NULL,
                        target_at TIMESTAMPTZ NOT NULL,
                        kind TEXT NOT NULL,
                        source TEXT NOT NULL,
                        model_version TEXT NOT NULL,
                        horizon_minutes INTEGER NOT NULL,
                        expected_power_w DOUBLE PRECISION NOT NULL,
                        p10_power_w DOUBLE PRECISION,
                        p50_power_w DOUBLE PRECISION,
                        p90_power_w DOUBLE PRECISION,
                        realized_power_w DOUBLE PRECISION,
                        error_w DOUBLE PRECISION,
                        absolute_error_w DOUBLE PRECISION,
                        absolute_pct_error DOUBLE PRECISION,
                        covered_by_p10_p90 BOOLEAN,
                        sample_count INTEGER NOT NULL DEFAULT 0,
                        evaluated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                        inserted_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                        updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
                        PRIMARY KEY (
                            issued_at, target_at, kind, source, model_version
                        )
                    )
                    """
                )
                cursor.execute(
                    """
                    SELECT create_hypertable(
                        'oracle_forecast_evaluations',
                        'target_at',
                        if_not_exists => TRUE
                    )
                    """
                )
                # Supports load_evaluation_summary's kind/horizon grouping.
                cursor.execute(
                    """
                    CREATE INDEX IF NOT EXISTS oracle_forecast_evaluations_kind_horizon_idx
                    ON oracle_forecast_evaluations (
                        kind, horizon_minutes, target_at DESC
                    )
                    """
                )
                connection.commit()

    def save_runs(
        self,
        solar_run: PowerForecastRun,
        load_run: PowerForecastRun,
        net_run: NetPowerForecastRun,
    ) -> int:
        """Upsert one solar, one load, and one net forecast run.

        Re-saving a run with the same primary key overwrites the previous
        point values (idempotent upsert). Returns the total number of point
        rows written across both tables.
        """
        # Ensure schema exists/migrated before writing.
        self.initialize()
        # Flatten both per-kind runs into power-point tuples, ordered to match
        # the INSERT column list below.
        power_rows = [
            (
                run.issued_at,
                point.target_at,
                run.kind.value,
                run.source,
                run.model_version,
                point.horizon_minutes,
                point.expected_power_w,
                point.p10_power_w,
                point.p50_power_w,
                point.p90_power_w,
                point.confidence,
            )
            for run in (solar_run, load_run)
            for point in run.points
        ]
        net_rows = [
            (
                net_run.issued_at,
                point.target_at,
                net_run.source,
                point.horizon_minutes,
                point.expected_net_power_w,
                point.safe_net_power_w,
                point.p10_net_power_w,
                point.p50_net_power_w,
                point.p90_net_power_w,
                point.solar_p50_power_w,
                point.load_p50_power_w,
                point.solar_p10_power_w,
                point.solar_p90_power_w,
                point.load_p10_power_w,
                point.load_p90_power_w,
            )
            for point in net_run.points
        ]
        with self._connection() as connection:
            with connection.cursor() as cursor:
                cursor.executemany(
                    """
                    INSERT INTO oracle_power_forecast_points (
                        issued_at, target_at, kind, source, model_version,
                        horizon_minutes, expected_power_w, p10_power_w,
                        p50_power_w, p90_power_w, confidence
                    )
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (issued_at, target_at, kind, source, model_version)
                    DO UPDATE SET
                        horizon_minutes = EXCLUDED.horizon_minutes,
                        expected_power_w = EXCLUDED.expected_power_w,
                        p10_power_w = EXCLUDED.p10_power_w,
                        p50_power_w = EXCLUDED.p50_power_w,
                        p90_power_w = EXCLUDED.p90_power_w,
                        confidence = EXCLUDED.confidence,
                        inserted_at = now()
                    """,
                    power_rows,
                )
                cursor.executemany(
                    """
                    INSERT INTO oracle_net_forecast_points (
                        issued_at, target_at, source, horizon_minutes,
                        expected_net_power_w, safe_net_power_w, p10_net_power_w,
                        p50_net_power_w, p90_net_power_w, solar_p50_power_w,
                        load_p50_power_w, solar_p10_power_w, solar_p90_power_w,
                        load_p10_power_w, load_p90_power_w
                    )
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (issued_at, target_at, source)
                    DO UPDATE SET
                        horizon_minutes = EXCLUDED.horizon_minutes,
                        expected_net_power_w = EXCLUDED.expected_net_power_w,
                        safe_net_power_w = EXCLUDED.safe_net_power_w,
                        p10_net_power_w = EXCLUDED.p10_net_power_w,
                        p50_net_power_w = EXCLUDED.p50_net_power_w,
                        p90_net_power_w = EXCLUDED.p90_net_power_w,
                        solar_p50_power_w = EXCLUDED.solar_p50_power_w,
                        load_p50_power_w = EXCLUDED.load_p50_power_w,
                        solar_p10_power_w = EXCLUDED.solar_p10_power_w,
                        solar_p90_power_w = EXCLUDED.solar_p90_power_w,
                        load_p10_power_w = EXCLUDED.load_p10_power_w,
                        load_p90_power_w = EXCLUDED.load_p90_power_w,
                        inserted_at = now()
                    """,
                    net_rows,
                )
                connection.commit()
        return len(power_rows) + len(net_rows)

    def load_recent_net_runs(
        self,
        lookback: timedelta = timedelta(hours=6),
        limit: int = 6,
    ) -> list[dict[str, object]]:
        """Load net runs from the standard lag ladder, capped at `lookback`.

        Thin wrapper over load_lagged_net_runs: keeps only the standard lags
        (1, 2, 6, 24, 48 h) that fit inside the lookback window.
        """
        return self.load_lagged_net_runs(
            lag_hours=[
                hour
                for hour in (1, 2, 6, 24, 48)
                if hour <= lookback.total_seconds() / 3600
            ],
            tolerance=timedelta(minutes=45),
            limit=limit,
        )

    def load_lagged_net_runs(
        self,
        lag_hours: list[int] | None = None,
        tolerance: timedelta = timedelta(minutes=45),
        limit: int = 5,
    ) -> list[dict[str, object]]:
        """Load one historical net-forecast run near each requested lag.

        For each lag (hours before now), picks the stored run whose issued_at
        is closest to now - lag within +/- `tolerance`. A given issued_at is
        used at most once even if it matches several lags. Only points at or
        after the run's own issued_at are returned. Results are dicts with
        lag_hours, issued_at, and a list of point dicts; quantile columns that
        predate the p10/p50 migration fall back via COALESCE to the safe /
        expected values.
        """
        if lag_hours is None:
            lag_hours = [1, 2, 6, 24, 48]
        now = datetime.now(timezone.utc)
        # (lag_hour, issued_at) pairs chosen so far.
        selected: list[tuple[int, datetime]] = []
        used_issued_at: set[datetime] = set()
        with self._connection() as connection:
            with connection.cursor() as cursor:
                for lag_hour in lag_hours:
                    target_issued_at = now - timedelta(hours=lag_hour)
                    # Nearest distinct issued_at inside the tolerance window.
                    cursor.execute(
                        """
                        SELECT issued_at
                        FROM oracle_net_forecast_points
                        WHERE issued_at BETWEEN %s AND %s
                        GROUP BY issued_at
                        ORDER BY abs(extract(epoch FROM (issued_at - %s)))
                        LIMIT 1
                        """,
                        (
                            target_issued_at - tolerance,
                            target_issued_at + tolerance,
                            target_issued_at,
                        ),
                    )
                    row = cursor.fetchone()
                    # Skip missing runs and runs already claimed by an earlier lag.
                    if row is None or row[0] in used_issued_at:
                        continue
                    selected.append((lag_hour, row[0]))
                    used_issued_at.add(row[0])
                    if len(selected) >= limit:
                        break
                runs: list[dict[str, object]] = []
                for lag_hour, issued_at in selected:
                    cursor.execute(
                        """
                        SELECT
                            target_at,
                            horizon_minutes,
                            expected_net_power_w,
                            safe_net_power_w,
                            COALESCE(p10_net_power_w, safe_net_power_w),
                            COALESCE(p50_net_power_w, expected_net_power_w),
                            p90_net_power_w,
                            solar_p50_power_w,
                            load_p50_power_w,
                            solar_p10_power_w,
                            solar_p90_power_w,
                            load_p10_power_w,
                            load_p90_power_w
                        FROM oracle_net_forecast_points
                        WHERE issued_at = %s AND target_at >= %s
                        ORDER BY target_at
                        """,
                        (issued_at, issued_at),
                    )
                    points = cursor.fetchall()
                    if not points:
                        continue
                    runs.append(
                        {
                            "lag_hours": lag_hour,
                            "issued_at": issued_at,
                            "points": [
                                {
                                    "target_at": row[0],
                                    "horizon_minutes": row[1],
                                    "expected_net_power_w": row[2],
                                    "safe_net_power_w": row[3],
                                    "p10_net_power_w": row[4],
                                    "p50_net_power_w": row[5],
                                    "p90_net_power_w": row[6],
                                    "solar_p50_power_w": row[7],
                                    "load_p50_power_w": row[8],
                                    "solar_p10_power_w": row[9],
                                    "solar_p90_power_w": row[10],
                                    "load_p10_power_w": row[11],
                                    "load_p90_power_w": row[12],
                                }
                                for row in points
                            ],
                        }
                    )
        return runs

    def load_lagged_power_runs(
        self,
        kind: str,
        lag_hours: list[int] | None = None,
        tolerance: timedelta = timedelta(minutes=45),
        limit: int = 5,
    ) -> list[dict[str, object]]:
        """Load one historical per-kind forecast run near each requested lag.

        Same selection strategy as load_lagged_net_runs, but runs are keyed by
        (issued_at, kind, source, model_version) and `kind` must be 'solar' or
        'load'.

        Raises:
            ValueError: if kind is not 'solar' or 'load'.
        """
        if kind not in {"solar", "load"}:
            raise ValueError("kind must be 'solar' or 'load'")
        if lag_hours is None:
            lag_hours = [1, 2, 6, 24, 48]
        now = datetime.now(timezone.utc)
        selected: list[tuple[int, datetime, str, str, str]] = []
        # Full run keys claimed so far, to avoid returning the same run twice.
        used_keys: set[tuple[datetime, str, str, str]] = set()
        with self._connection() as connection:
            with connection.cursor() as cursor:
                for lag_hour in lag_hours:
                    target_issued_at = now - timedelta(hours=lag_hour)
                    cursor.execute(
                        """
                        SELECT issued_at, kind, source, model_version
                        FROM oracle_power_forecast_points
                        WHERE kind = %s AND issued_at BETWEEN %s AND %s
                        GROUP BY issued_at, kind, source, model_version
                        ORDER BY abs(extract(epoch FROM (issued_at - %s)))
                        LIMIT 1
                        """,
                        (
                            kind,
                            target_issued_at - tolerance,
                            target_issued_at + tolerance,
                            target_issued_at,
                        ),
                    )
                    row = cursor.fetchone()
                    if row is None:
                        continue
                    key = (row[0], row[1], row[2], row[3])
                    if key in used_keys:
                        continue
                    selected.append((lag_hour, row[0], row[1], row[2], row[3]))
                    used_keys.add(key)
                    if len(selected) >= limit:
                        break
                runs: list[dict[str, object]] = []
                for lag_hour, issued_at, run_kind, source, model_version in selected:
                    cursor.execute(
                        """
                        SELECT
                            target_at,
                            horizon_minutes,
                            expected_power_w,
                            p10_power_w,
                            p50_power_w,
                            p90_power_w,
                            confidence
                        FROM oracle_power_forecast_points
                        WHERE issued_at = %s
                          AND kind = %s
                          AND source = %s
                          AND model_version = %s
                          AND target_at >= %s
                        ORDER BY target_at
                        """,
                        (issued_at, run_kind, source, model_version, issued_at),
                    )
                    points = cursor.fetchall()
                    if not points:
                        continue
                    runs.append(
                        {
                            "lag_hours": lag_hour,
                            "issued_at": issued_at,
                            "kind": run_kind,
                            "source": source,
                            "model_version": model_version,
                            "points": [
                                {
                                    "target_at": row[0],
                                    "horizon_minutes": row[1],
                                    "expected_power_w": row[2],
                                    "p10_power_w": row[3],
                                    "p50_power_w": row[4],
                                    "p90_power_w": row[5],
                                    "confidence": row[6],
                                }
                                for row in points
                            ],
                        }
                    )
        return runs

    def evaluate_due_forecasts(
        self,
        actual_window: timedelta = timedelta(minutes=5),
        lookback: timedelta = timedelta(days=7),
        limit: int = 1000,
    ) -> int:
        """Evaluate matured forecast points against realized telemetry.

        Scans forecasts whose target_at falls within [now - lookback,
        now - actual_window] and that have no evaluation yet (or an empty one),
        averaging realized samples over `actual_window` after each target.
        Power (solar/load) forecasts are evaluated first; any remaining budget
        from `limit` goes to net forecasts. Returns the number of evaluation
        rows written.
        """
        self.initialize()
        start_at = datetime.now(timezone.utc) - lookback
        with self._connection() as connection:
            with connection.cursor() as cursor:
                power_count = self._evaluate_due_power_forecasts(
                    cursor=cursor,
                    actual_window=actual_window,
                    start_at=start_at,
                    limit=limit,
                )
                # Net evaluations only consume whatever budget power left over.
                remaining_limit = max(limit - power_count, 0)
                net_count = 0
                if remaining_limit > 0:
                    net_count = self._evaluate_due_net_forecasts(
                        cursor=cursor,
                        actual_window=actual_window,
                        start_at=start_at,
                        limit=remaining_limit,
                    )
                connection.commit()
        return power_count + net_count

    def load_evaluation_summary(
        self,
        lookback: timedelta = timedelta(days=7),
    ) -> list[dict[str, object]]:
        """Aggregate evaluation accuracy per (kind, source, model, horizon bucket).

        Only rows with a realized value inside the lookback window count.
        Horizon buckets: 0-2h, 2-4h, 4-8h, 8-16h, 16-24h (the last bucket is
        open-ended in SQL: anything >= 960 minutes). interval_coverage is the
        fraction of evaluations whose realized value fell within [p10, p90],
        ignoring rows where coverage is unknown (NULL).
        """
        start_at = datetime.now(timezone.utc) - lookback
        with self._connection() as connection:
            with connection.cursor() as cursor:
                cursor.execute(
                    """
                    WITH bucketed AS (
                        SELECT
                            *,
                            CASE
                                WHEN horizon_minutes < 120 THEN 1
                                WHEN horizon_minutes < 240 THEN 2
                                WHEN horizon_minutes < 480 THEN 3
                                WHEN horizon_minutes < 960 THEN 4
                                ELSE 5
                            END AS horizon_bucket,
                            CASE
                                WHEN horizon_minutes < 120 THEN '0-2h'
                                WHEN horizon_minutes < 240 THEN '2-4h'
                                WHEN horizon_minutes < 480 THEN '4-8h'
                                WHEN horizon_minutes < 960 THEN '8-16h'
                                ELSE '16-24h'
                            END AS horizon_label
                        FROM oracle_forecast_evaluations
                        WHERE target_at >= %s AND realized_power_w IS NOT NULL
                    )
                    SELECT
                        kind,
                        source,
                        model_version,
                        horizon_bucket,
                        horizon_label,
                        min(horizon_minutes) AS min_horizon_minutes,
                        max(horizon_minutes) AS max_horizon_minutes,
                        count(*) AS evaluated_count,
                        avg(error_w) AS mean_error_w,
                        avg(absolute_error_w) AS mean_absolute_error_w,
                        percentile_cont(0.50) WITHIN GROUP (
                            ORDER BY absolute_error_w
                        ) AS median_absolute_error_w,
                        avg(absolute_pct_error) AS mean_absolute_pct_error,
                        avg(
                            CASE
                                WHEN covered_by_p10_p90 IS NULL THEN NULL
                                WHEN covered_by_p10_p90 THEN 1.0
                                ELSE 0.0
                            END
                        ) AS interval_coverage
                    FROM bucketed
                    GROUP BY kind, source, model_version, horizon_bucket, horizon_label
                    ORDER BY kind, source, model_version, horizon_bucket
                    """,
                    (start_at,),
                )
                rows = cursor.fetchall()
        return [
            {
                "kind": row[0],
                "source": row[1],
                "model_version": row[2],
                "horizon_bucket": row[3],
                "horizon_label": row[4],
                "min_horizon_minutes": row[5],
                "max_horizon_minutes": row[6],
                "evaluated_count": row[7],
                "mean_error_w": row[8],
                "mean_absolute_error_w": row[9],
                "median_absolute_error_w": row[10],
                "mean_absolute_pct_error": row[11],
                "interval_coverage": row[12],
            }
            for row in rows
        ]

    def _evaluate_due_power_forecasts(
        self,
        cursor: object,
        actual_window: timedelta,
        start_at: datetime,
        limit: int,
    ) -> int:
        """Evaluate up to `limit` matured solar/load forecast points.

        Single INSERT ... SELECT pipeline: pick unevaluated (or zero-sample)
        candidates, join each to the average of matching sigen_plant_snapshots
        samples inside [target_at, target_at + actual_window), then upsert the
        error metrics. Error is measured against p50; pct error is NULL when
        |realized| < 1 W to avoid division blow-up. Returns cursor.rowcount.
        Does not commit — caller owns the transaction.
        """
        cursor.execute(
            """
            WITH candidates AS (
                SELECT
                    forecast.issued_at,
                    forecast.target_at,
                    forecast.kind,
                    forecast.source,
                    forecast.model_version,
                    forecast.horizon_minutes,
                    forecast.expected_power_w,
                    forecast.p10_power_w,
                    forecast.p50_power_w,
                    forecast.p90_power_w
                FROM oracle_power_forecast_points AS forecast
                LEFT JOIN oracle_forecast_evaluations AS evaluation
                    ON evaluation.issued_at = forecast.issued_at
                    AND evaluation.target_at = forecast.target_at
                    AND evaluation.kind = forecast.kind
                    AND evaluation.source = forecast.source
                    AND evaluation.model_version = forecast.model_version
                WHERE forecast.target_at >= %s
                    AND forecast.target_at <= now() - %s
                    AND (
                        evaluation.issued_at IS NULL
                        OR evaluation.sample_count = 0
                    )
                ORDER BY forecast.target_at, forecast.issued_at
                LIMIT %s
            ),
            realized AS (
                SELECT
                    candidates.*,
                    actual.realized_power_w,
                    actual.sample_count
                FROM candidates
                LEFT JOIN LATERAL (
                    SELECT
                        avg(
                            CASE candidates.kind
                                WHEN 'solar' THEN snapshot.solar_power_w
                                WHEN 'load' THEN snapshot.load_power_w
                                ELSE NULL
                            END
                        ) AS realized_power_w,
                        count(*) FILTER (
                            WHERE CASE candidates.kind
                                WHEN 'solar' THEN snapshot.solar_power_w
                                WHEN 'load' THEN snapshot.load_power_w
                                ELSE NULL
                            END IS NOT NULL
                        ) AS sample_count
                    FROM sigen_plant_snapshots AS snapshot
                    WHERE snapshot.observed_at >= candidates.target_at
                        AND snapshot.observed_at < candidates.target_at + %s
                ) AS actual ON TRUE
            )
            INSERT INTO oracle_forecast_evaluations (
                issued_at, target_at, kind, source, model_version,
                horizon_minutes, expected_power_w, p10_power_w, p50_power_w,
                p90_power_w, realized_power_w, error_w, absolute_error_w,
                absolute_pct_error, covered_by_p10_p90, sample_count,
                evaluated_at
            )
            SELECT
                issued_at,
                target_at,
                kind,
                source,
                model_version,
                horizon_minutes,
                expected_power_w,
                p10_power_w,
                p50_power_w,
                p90_power_w,
                realized_power_w,
                realized_power_w - p50_power_w,
                abs(realized_power_w - p50_power_w),
                CASE
                    WHEN abs(realized_power_w) < 1 THEN NULL
                    ELSE abs(realized_power_w - p50_power_w) / abs(realized_power_w)
                END,
                CASE
                    WHEN realized_power_w IS NULL THEN NULL
                    ELSE realized_power_w BETWEEN p10_power_w AND p90_power_w
                END,
                COALESCE(sample_count, 0),
                now()
            FROM realized
            ON CONFLICT (
                issued_at, target_at, kind, source, model_version
            )
            DO UPDATE SET
                horizon_minutes = EXCLUDED.horizon_minutes,
                expected_power_w = EXCLUDED.expected_power_w,
                p10_power_w = EXCLUDED.p10_power_w,
                p50_power_w = EXCLUDED.p50_power_w,
                p90_power_w = EXCLUDED.p90_power_w,
                realized_power_w = EXCLUDED.realized_power_w,
                error_w = EXCLUDED.error_w,
                absolute_error_w = EXCLUDED.absolute_error_w,
                absolute_pct_error = EXCLUDED.absolute_pct_error,
                covered_by_p10_p90 = EXCLUDED.covered_by_p10_p90,
                sample_count = EXCLUDED.sample_count,
                evaluated_at = EXCLUDED.evaluated_at,
                updated_at = now()
            """,
            (start_at, actual_window, limit, actual_window),
        )
        return cursor.rowcount

    def _evaluate_due_net_forecasts(
        self,
        cursor: object,
        actual_window: timedelta,
        start_at: datetime,
        limit: int,
    ) -> int:
        """Evaluate up to `limit` matured net forecast points.

        Mirror of _evaluate_due_power_forecasts for the net table: rows are
        stored into the shared evaluations table under kind='net' and the
        synthetic model_version 'net_forecaster_v1'. Realized net power is
        avg(solar - load) over samples where both components are present;
        p10/p50 fall back to safe/expected values for pre-migration rows, and
        coverage is NULL when p90 is missing. Returns cursor.rowcount.
        Does not commit — caller owns the transaction.
        """
        cursor.execute(
            """
            WITH candidates AS (
                SELECT
                    forecast.issued_at,
                    forecast.target_at,
                    'net'::text AS kind,
                    forecast.source,
                    'net_forecaster_v1'::text AS model_version,
                    forecast.horizon_minutes,
                    forecast.expected_net_power_w AS expected_power_w,
                    COALESCE(forecast.p10_net_power_w, forecast.safe_net_power_w)
                        AS p10_power_w,
                    COALESCE(forecast.p50_net_power_w, forecast.expected_net_power_w)
                        AS p50_power_w,
                    forecast.p90_net_power_w AS p90_power_w
                FROM oracle_net_forecast_points AS forecast
                LEFT JOIN oracle_forecast_evaluations AS evaluation
                    ON evaluation.issued_at = forecast.issued_at
                    AND evaluation.target_at = forecast.target_at
                    AND evaluation.kind = 'net'
                    AND evaluation.source = forecast.source
                    AND evaluation.model_version = 'net_forecaster_v1'
                WHERE forecast.target_at >= %s
                    AND forecast.target_at <= now() - %s
                    AND (
                        evaluation.issued_at IS NULL
                        OR evaluation.sample_count = 0
                    )
                ORDER BY forecast.target_at, forecast.issued_at
                LIMIT %s
            ),
            realized AS (
                SELECT
                    candidates.*,
                    actual.realized_power_w,
                    actual.sample_count
                FROM candidates
                LEFT JOIN LATERAL (
                    SELECT
                        avg(snapshot.solar_power_w - snapshot.load_power_w)
                            AS realized_power_w,
                        count(*) FILTER (
                            WHERE snapshot.solar_power_w IS NOT NULL
                                AND snapshot.load_power_w IS NOT NULL
                        ) AS sample_count
                    FROM sigen_plant_snapshots AS snapshot
                    WHERE snapshot.observed_at >= candidates.target_at
                        AND snapshot.observed_at < candidates.target_at + %s
                ) AS actual ON TRUE
            )
            INSERT INTO oracle_forecast_evaluations (
                issued_at, target_at, kind, source, model_version,
                horizon_minutes, expected_power_w, p10_power_w, p50_power_w,
                p90_power_w, realized_power_w, error_w, absolute_error_w,
                absolute_pct_error, covered_by_p10_p90, sample_count,
                evaluated_at
            )
            SELECT
                issued_at,
                target_at,
                kind,
                source,
                model_version,
                horizon_minutes,
                expected_power_w,
                p10_power_w,
                p50_power_w,
                p90_power_w,
                realized_power_w,
                realized_power_w - p50_power_w,
                abs(realized_power_w - p50_power_w),
                CASE
                    WHEN abs(realized_power_w) < 1 THEN NULL
                    ELSE abs(realized_power_w - p50_power_w) / abs(realized_power_w)
                END,
                CASE
                    WHEN realized_power_w IS NULL OR p90_power_w IS NULL THEN NULL
                    ELSE realized_power_w BETWEEN p10_power_w AND p90_power_w
                END,
                COALESCE(sample_count, 0),
                now()
            FROM realized
            ON CONFLICT (
                issued_at, target_at, kind, source, model_version
            )
            DO UPDATE SET
                horizon_minutes = EXCLUDED.horizon_minutes,
                expected_power_w = EXCLUDED.expected_power_w,
                p10_power_w = EXCLUDED.p10_power_w,
                p50_power_w = EXCLUDED.p50_power_w,
                p90_power_w = EXCLUDED.p90_power_w,
                realized_power_w = EXCLUDED.realized_power_w,
                error_w = EXCLUDED.error_w,
                absolute_error_w = EXCLUDED.absolute_error_w,
                absolute_pct_error = EXCLUDED.absolute_pct_error,
                covered_by_p10_p90 = EXCLUDED.covered_by_p10_p90,
                sample_count = EXCLUDED.sample_count,
                evaluated_at = EXCLUDED.evaluated_at,
                updated_at = now()
            """,
            (start_at, actual_window, limit, actual_window),
        )
        return cursor.rowcount

    @contextmanager
    def _connection(self) -> Iterator[object]:
        """Yield a psycopg connection to the configured database.

        psycopg is imported lazily so the module can be imported without the
        driver installed; a missing driver surfaces as a configuration error.
        The `with psycopg.connect(...)` block closes the connection on exit.
        """
        try:
            import psycopg
        except ImportError as error:
            raise OracleStoreConfigurationError(
                "Install dependencies with `python3 -m pip install -r requirements.txt`"
            ) from error
        with psycopg.connect(self.config.database_url) as connection:
            yield connection