from __future__ import annotations from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timedelta, timezone from os import environ from typing import Iterator from gibil.classes.models import WeatherForecastPoint, WeatherForecastRun, WeatherResolvedTruth from gibil.classes.weather.display import WeatherDisplayDataset class WeatherStoreConfigurationError(RuntimeError): pass @dataclass(frozen=True) class WeatherStoreConfig: database_url: str @classmethod def from_env(cls) -> "WeatherStoreConfig": database_url = environ.get("ASTRAPE_DATABASE_URL") if not database_url: raise WeatherStoreConfigurationError( "ASTRAPE_DATABASE_URL is required for weather storage" ) return cls(database_url=database_url) class WeatherStore: """Persists external weather forecasts and resolved truth in TimescaleDB.""" def __init__(self, config: WeatherStoreConfig) -> None: self.config = config @classmethod def from_env(cls) -> "WeatherStore": return cls(WeatherStoreConfig.from_env()) def initialize(self) -> None: with self._connection() as connection: with connection.cursor() as cursor: cursor.execute("CREATE EXTENSION IF NOT EXISTS timescaledb") cursor.execute( """ CREATE TABLE IF NOT EXISTS weather_forecast_points ( issued_at TIMESTAMPTZ NOT NULL, target_at TIMESTAMPTZ NOT NULL, horizon_hours INTEGER NOT NULL, source TEXT NOT NULL, latitude DOUBLE PRECISION NOT NULL, longitude DOUBLE PRECISION NOT NULL, temperature_c DOUBLE PRECISION, shortwave_radiation_w_m2 DOUBLE PRECISION, cloud_cover_pct DOUBLE PRECISION, inserted_at TIMESTAMPTZ NOT NULL DEFAULT now(), PRIMARY KEY (issued_at, target_at, source) ) """ ) cursor.execute( """ SELECT create_hypertable( 'weather_forecast_points', 'target_at', if_not_exists => TRUE ) """ ) cursor.execute( """ CREATE TABLE IF NOT EXISTS weather_resolved_truth ( resolved_at TIMESTAMPTZ NOT NULL, source TEXT NOT NULL, temperature_c DOUBLE PRECISION, shortwave_radiation_w_m2 DOUBLE PRECISION, cloud_cover_pct DOUBLE PRECISION, inserted_at TIMESTAMPTZ NOT NULL DEFAULT now(), PRIMARY KEY (resolved_at, source) ) """ ) cursor.execute( """ ALTER TABLE weather_resolved_truth ADD COLUMN IF NOT EXISTS cloud_cover_pct DOUBLE PRECISION """ ) cursor.execute( """ SELECT create_hypertable( 'weather_resolved_truth', 'resolved_at', if_not_exists => TRUE ) """ ) connection.commit() def save_forecast_run(self, forecast_run: WeatherForecastRun) -> int: rows = [ ( point.issued_at, point.target_at, point.horizon_hours, forecast_run.source, forecast_run.latitude, forecast_run.longitude, point.temperature_c, point.shortwave_radiation_w_m2, point.cloud_cover_pct, ) for point in forecast_run.points ] if not rows: return 0 with self._connection() as connection: with connection.cursor() as cursor: cursor.executemany( """ INSERT INTO weather_forecast_points ( issued_at, target_at, horizon_hours, source, latitude, longitude, temperature_c, shortwave_radiation_w_m2, cloud_cover_pct ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) ON CONFLICT (issued_at, target_at, source) DO UPDATE SET horizon_hours = EXCLUDED.horizon_hours, latitude = EXCLUDED.latitude, longitude = EXCLUDED.longitude, temperature_c = EXCLUDED.temperature_c, shortwave_radiation_w_m2 = EXCLUDED.shortwave_radiation_w_m2, cloud_cover_pct = EXCLUDED.cloud_cover_pct, inserted_at = now() """, rows, ) connection.commit() return len(rows) def save_resolved_truth(self, truth_points: list[WeatherResolvedTruth]) -> int: rows = [ ( point.resolved_at, point.source, point.temperature_c, point.shortwave_radiation_w_m2, point.cloud_cover_pct, ) for point in truth_points ] if not rows: return 0 with self._connection() as connection: with connection.cursor() as cursor: cursor.executemany( """ INSERT INTO weather_resolved_truth ( resolved_at, source, temperature_c, shortwave_radiation_w_m2, cloud_cover_pct ) VALUES (%s, %s, %s, %s, %s) ON CONFLICT (resolved_at, source) DO UPDATE SET temperature_c = EXCLUDED.temperature_c, shortwave_radiation_w_m2 = EXCLUDED.shortwave_radiation_w_m2, cloud_cover_pct = EXCLUDED.cloud_cover_pct, inserted_at = now() """, rows, ) connection.commit() return len(rows) def save_zero_hour_forecast_as_truth( self, forecast_run: WeatherForecastRun ) -> int: truth_points = [ WeatherResolvedTruth( resolved_at=point.issued_at, source="open_meteo_zero_hour", temperature_c=point.temperature_c, shortwave_radiation_w_m2=point.shortwave_radiation_w_m2, cloud_cover_pct=point.cloud_cover_pct, ) for point in forecast_run.points if point.horizon_hours == 0 ] return self.save_resolved_truth(truth_points) def load_latest_forecast_points( self, start_at: datetime, end_at: datetime, ) -> list[WeatherForecastPoint]: with self._connection() as connection: with connection.cursor() as cursor: cursor.execute( """ SELECT issued_at, target_at, horizon_hours, source, temperature_c, shortwave_radiation_w_m2, cloud_cover_pct FROM ( SELECT issued_at, target_at, horizon_hours, source, temperature_c, shortwave_radiation_w_m2, cloud_cover_pct, ROW_NUMBER() OVER ( PARTITION BY target_at ORDER BY issued_at DESC ) as rn FROM weather_forecast_points WHERE target_at >= %s AND target_at <= %s ) as ranked WHERE rn = 1 ORDER BY target_at LIMIT 5000 """, (start_at, end_at), ) rows = cursor.fetchall() return [ WeatherForecastPoint( issued_at=row[0], target_at=row[1], horizon_hours=row[2], source=row[3], temperature_c=row[4], shortwave_radiation_w_m2=row[5], cloud_cover_pct=row[6], ) for row in rows ] def load_display_dataset( self, start_at: datetime | None = None, end_at: datetime | None = None, ) -> WeatherDisplayDataset: now = datetime.now(timezone.utc) if start_at is None: start_at = now - timedelta(hours=24) if end_at is None: end_at = now + timedelta(hours=48) with self._connection() as connection: with connection.cursor() as cursor: cursor.execute( """ SELECT issued_at, target_at, horizon_hours, source, temperature_c, shortwave_radiation_w_m2, cloud_cover_pct FROM ( SELECT issued_at, target_at, horizon_hours, source, temperature_c, shortwave_radiation_w_m2, cloud_cover_pct, ROW_NUMBER() OVER (PARTITION BY target_at, horizon_hours ORDER BY issued_at DESC) as rn FROM weather_forecast_points WHERE target_at >= %s AND target_at <= %s ) as ranked WHERE rn = 1 ORDER BY target_at, horizon_hours LIMIT 5000 """, (start_at, end_at), ) forecast_rows = cursor.fetchall() cursor.execute( """ SELECT resolved_at, source, temperature_c, shortwave_radiation_w_m2, cloud_cover_pct FROM weather_resolved_truth WHERE resolved_at >= %s AND resolved_at <= %s ORDER BY resolved_at LIMIT 5000 """, (start_at, now), ) truth_rows = cursor.fetchall() return WeatherDisplayDataset( forecast_points=[ WeatherForecastPoint( issued_at=row[0], target_at=row[1], horizon_hours=row[2], source=row[3], temperature_c=row[4], shortwave_radiation_w_m2=row[5], cloud_cover_pct=row[6], ) for row in forecast_rows ], resolved_truth=[ WeatherResolvedTruth( resolved_at=row[0], source=row[1], temperature_c=row[2], shortwave_radiation_w_m2=row[3], cloud_cover_pct=row[4], ) for row in truth_rows ], ) @contextmanager def _connection(self) -> Iterator[object]: try: import psycopg except ImportError as error: raise WeatherStoreConfigurationError( "Install dependencies with `python3 -m pip install -r requirements.txt`" ) from error with psycopg.connect(self.config.database_url) as connection: yield connection