Add new daemons and debug scripts for Sigenergy and Oracle functionalities
- Implement `sigen_daemon.py` to poll Sigenergy plant metrics and store snapshots. - Create `web_daemon.py` for serving a web interface with various endpoints. - Add debug scripts: - `debug_duplicates.py` to find duplicate target times in forecast data. - `debug_energy_forecast.py` to print baseline energy forecast curves. - `debug_oracle_evaluations.py` to run the oracle evaluator. - `debug_sigen.py` to inspect stored Sigenergy plant snapshots. - `debug_weather.py` to trace resolved truth data. - `modbus_test.py` for exploring Sigenergy plants or inverters over Modbus TCP. - Introduce `oracle_evaluator.py` for evaluating stored oracle predictions against actuals. - Add TCN training scripts in `tcn` directory for training usage sequence models.
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Long-running Astrape service entrypoints."""
|
||||
@@ -7,12 +7,12 @@ from sys import stderr
|
||||
from time import sleep
|
||||
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
from gibil.classes.weather_builder import (
|
||||
from gibil.classes.weather.builder import (
|
||||
OpenMeteoArchiveClient,
|
||||
OpenMeteoClient,
|
||||
WeatherBuilder,
|
||||
)
|
||||
from gibil.classes.weather_store import WeatherStore
|
||||
from gibil.classes.weather.store import WeatherStore
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -0,0 +1,115 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from os import environ
|
||||
from sys import stderr
|
||||
from time import sleep
|
||||
|
||||
from gibil.classes.oracle.builder import EnergyOracleBuilder
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
from gibil.classes.oracle.store import OracleStore
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class OracleDaemonConfig:
    """Runtime knobs for the oracle daemon, sourced from the environment."""

    poll_seconds: float
    evaluate_forecasts: bool
    evaluation_actual_window_minutes: float
    evaluation_lookback_hours: float
    evaluation_limit: int

    @classmethod
    def from_env(cls) -> "OracleDaemonConfig":
        """Build a config from ASTRAPE_ORACLE_* environment variables."""
        evaluate_flag = environ.get("ASTRAPE_ORACLE_EVALUATE_FORECASTS", "1")
        window_minutes = environ.get("ASTRAPE_ORACLE_EVALUATION_WINDOW_MINUTES", "5")
        lookback_hours = environ.get("ASTRAPE_ORACLE_EVALUATION_LOOKBACK_HOURS", "168")
        return cls(
            poll_seconds=float(environ.get("ASTRAPE_ORACLE_POLL_SECONDS", "300")),
            # Anything other than an explicit "0"/"false"/"no" enables evaluation.
            evaluate_forecasts=evaluate_flag.lower() not in {"0", "false", "no"},
            evaluation_actual_window_minutes=float(window_minutes),
            evaluation_lookback_hours=float(lookback_hours),
            evaluation_limit=int(environ.get("ASTRAPE_ORACLE_EVALUATION_LIMIT", "1000")),
        )
|
||||
|
||||
|
||||
class OracleDaemon:
    """Periodically stores oracle projection curves for evaluation."""

    def __init__(
        self,
        config: OracleDaemonConfig,
        builder: EnergyOracleBuilder,
        store: OracleStore,
    ) -> None:
        self.config = config
        self.builder = builder
        self.store = store

    @classmethod
    def from_env(cls) -> "OracleDaemon":
        """Wire up the daemon entirely from environment variables."""
        return cls(
            config=OracleDaemonConfig.from_env(),
            builder=EnergyOracleBuilder.from_env(),
            store=OracleStore.from_env(),
        )

    def run_once(self) -> int:
        """Build and persist one set of curves; evaluate due forecasts if enabled.

        Returns the number of records saved, plus the number of forecasts
        evaluated when evaluation is turned on.
        """
        solar_run, load_run, net_run = self.builder.build()
        saved_count = self.store.save_runs(solar_run, load_run, net_run)
        if not self.config.evaluate_forecasts:
            return saved_count

        from datetime import timedelta

        window = timedelta(minutes=self.config.evaluation_actual_window_minutes)
        evaluated_count = self.store.evaluate_due_forecasts(
            actual_window=window,
            lookback=timedelta(hours=self.config.evaluation_lookback_hours),
            limit=self.config.evaluation_limit,
        )
        return saved_count + evaluated_count

    def run_forever(self) -> None:
        """Poll and store indefinitely, logging errors without exiting."""
        self.store.initialize()
        while True:
            try:
                total = self.run_once()
            except Exception as error:
                print(f"oracle_poll_error={error}", file=stderr, flush=True)
            else:
                print(f"stored_oracle_records={total}", flush=True)
            sleep(self.config.poll_seconds)
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entrypoint: run one cycle with --once, otherwise loop forever."""
    try:
        EnvLoader().load()
        options = parse_args()
        daemon = OracleDaemon.from_env()
        if options.once:
            print(f"stored_oracle_records={daemon.run_once()}", flush=True)
        else:
            daemon.run_forever()
    except Exception as error:
        print(f"oracle_daemon_startup_error={error}", file=stderr)
        raise SystemExit(1) from error
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the oracle daemon."""
    parser = argparse.ArgumentParser(
        description="Store Astrape oracle projection curves."
    )
    once_help = "Save one set of oracle curves and exit."
    parser.add_argument("--once", action="store_true", help=once_help)
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from dataclasses import dataclass
|
||||
from os import environ
|
||||
from sys import stderr
|
||||
from time import sleep
|
||||
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
from gibil.classes.sigen.builder import SigenPlantClient
|
||||
from gibil.classes.sigen.store import SigenStore
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class SigenDaemonConfig:
    """Polling configuration for the Sigenergy daemon."""

    poll_seconds: float

    @classmethod
    def from_env(cls) -> "SigenDaemonConfig":
        """Read SIGEN_POLL_SECONDS (default: 5 seconds between polls)."""
        raw_interval = environ.get("SIGEN_POLL_SECONDS", "5")
        return cls(poll_seconds=float(raw_interval))
|
||||
|
||||
|
||||
class SigenDaemon:
    """Polls Sigenergy plant metrics and stores normalized snapshots."""

    def __init__(
        self,
        config: SigenDaemonConfig,
        plant_client: SigenPlantClient,
        sigen_store: SigenStore,
    ) -> None:
        self.config = config
        self.plant_client = plant_client
        self.sigen_store = sigen_store

    @classmethod
    def from_env(cls) -> "SigenDaemon":
        """Wire the daemon from environment variables."""
        return cls(
            config=SigenDaemonConfig.from_env(),
            plant_client=SigenPlantClient.from_env(),
            sigen_store=SigenStore.from_env(),
        )

    def initialize(self) -> None:
        """Prepare backing storage before polling starts."""
        self.sigen_store.initialize()

    def run_once(self) -> int:
        """Fetch one plant snapshot and persist it; return the save count."""
        return self.sigen_store.save_snapshot(self.plant_client.fetch_snapshot())

    def run_forever(self) -> None:
        """Poll indefinitely, logging errors and continuing."""
        self.initialize()
        while True:
            try:
                stored = self.run_once()
            except Exception as error:
                print(f"sigen_poll_error={error}", file=stderr, flush=True)
            else:
                print(f"stored_sigen_plant_snapshots={stored}", flush=True)
            sleep(self.config.poll_seconds)
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entrypoint: save one snapshot with --once, otherwise poll forever.

    Raises SystemExit(1) on any startup failure.
    """
    try:
        EnvLoader().load()
        # Parse the command line BEFORE building the daemon: argparse errors
        # and --help must not depend on (or fail because of) environment
        # wiring. This also matches the oracle daemon's main().
        args = parse_args()
        daemon = SigenDaemon.from_env()
        if args.once:
            daemon.initialize()
            saved_count = daemon.run_once()
            print(f"stored_sigen_plant_snapshots={saved_count}", flush=True)
            return

        daemon.run_forever()
    except Exception as error:
        print(f"sigen_daemon_startup_error={error}", file=stderr)
        raise SystemExit(1) from error
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the Sigenergy daemon."""
    parser = argparse.ArgumentParser(
        description="Poll Sigenergy plant metrics into Astrape's database."
    )
    once_help = "Initialize storage, save one snapshot, and exit."
    parser.add_argument("--once", action="store_true", help=once_help)
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from importlib import import_module, reload
|
||||
from os import environ
|
||||
from pathlib import Path
|
||||
import json
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
|
||||
# Load .env-style configuration at import time so HOST/PORT below see it.
EnvLoader().load()

# Bind address and port for the web UI (overridable via the environment).
HOST = environ.get("ASTRAPE_WEB_HOST", "0.0.0.0")
PORT = int(environ.get("ASTRAPE_WEB_PORT", "8080"))
# Repository root: this file sits two directory levels below it.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Source files whose mtimes form the /api/ui-version fingerprint; these are
# the same modules _webui() reloads per request for dev-time hot reload.
WATCHED_PATHS = [
    PROJECT_ROOT / "gibil" / "classes" / "webui.py",
    PROJECT_ROOT / "gibil" / "classes" / "weather" / "display.py",
    PROJECT_ROOT / "gibil" / "classes" / "oracle" / "display.py",
    PROJECT_ROOT / "gibil" / "classes" / "oracle" / "quality_display.py",
    PROJECT_ROOT / "gibil" / "classes" / "weather" / "store.py",
    PROJECT_ROOT / "gibil" / "classes" / "oracle" / "store.py",
    PROJECT_ROOT / "gibil" / "classes" / "oracle" / "builder.py",
    PROJECT_ROOT / "gibil" / "classes" / "oracle" / "config.py",
    PROJECT_ROOT / "gibil" / "classes" / "sigen" / "store.py",
]
|
||||
|
||||
|
||||
class AstrapeWebHandler(BaseHTTPRequestHandler):
    """Serves the Astrape HTML pages and JSON API endpoints.

    Every request re-imports and reloads the UI/store modules via
    ``_webui()``, giving crude hot reload during development at the cost
    of per-request import work.
    """

    def do_GET(self) -> None:
        """Route GET requests to page renderers and JSON endpoints."""
        parsed = urlparse(self.path)
        path = parsed.path

        if path in {"/", "/oracle"}:
            self._send_html(self._webui().render_page("oracle"))
            return

        if path == "/weather":
            self._send_html(self._webui().render_page("weather"))
            return

        if path == "/quality":
            self._send_html(self._webui().render_page("quality"))
            return

        if path == "/api/weather":
            self._send_json_text(self._webui().weather_payload())
            return

        if path == "/api/oracle":
            self._send_json_text(self._webui().oracle_payload())
            return

        if path == "/api/oracle-quality":
            params = parse_qs(parsed.query)
            # Optional ?lookback_hours=... query parameter; defaults to one week.
            lookback_hours = self._float_param(params, "lookback_hours", 168)
            self._send_json_text(
                self._webui().oracle_quality_payload(
                    lookback_hours=lookback_hours
                )
            )
            return

        if path == "/api/ui-version":
            # Clients poll this to detect when watched source files changed.
            self._send_json_text(json.dumps({"version": self._ui_version()}))
            return

        self.send_error(404)

    def log_message(self, format: str, *args: object) -> None:
        """Log one request line to stdout (base class logs to stderr)."""
        print(f"{self.address_string()} - {format % args}")

    def _webui(self):
        """Import, reload, and instantiate the WebUI for this request.

        Reload order matters: dependency modules are reloaded before the
        modules that import them, so each sees the freshest versions.
        """
        weather_store_module = import_module("gibil.classes.weather.store")
        sigen_store_module = import_module("gibil.classes.sigen.store")
        oracle_store_module = import_module("gibil.classes.oracle.store")
        oracle_builder_module = import_module("gibil.classes.oracle.builder")
        oracle_display_module = import_module("gibil.classes.oracle.display")
        oracle_quality_display_module = import_module(
            "gibil.classes.oracle.quality_display"
        )
        weather_display_module = import_module("gibil.classes.weather.display")
        webui_module = import_module("gibil.classes.webui")
        reload(weather_store_module)
        reload(sigen_store_module)
        reload(oracle_store_module)
        reload(oracle_builder_module)
        reload(oracle_display_module)
        reload(oracle_quality_display_module)
        reload(weather_display_module)
        reload(webui_module)
        return webui_module.WebUI()

    def _float_param(
        self,
        params: dict[str, list[str]],
        key: str,
        default: float,
    ) -> float:
        """Return the first value of *key* parsed as float, or *default*
        when the key is missing or unparsable."""
        values = params.get(key)
        if not values:
            return default
        try:
            return float(values[0])
        except ValueError:
            return default

    def _ui_version(self) -> str:
        """Fingerprint the watched source files by mtime (ns resolution)."""
        mtimes = [
            str(path.stat().st_mtime_ns)
            for path in WATCHED_PATHS
            if path.exists()
        ]
        return ".".join(mtimes)

    def _send_html(self, body: str) -> None:
        """Write a 200 text/html response with an explicit Content-Length."""
        encoded = body.encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)

    def _send_json_text(self, body: str) -> None:
        """Write a 200 application/json response; no-store keeps API data fresh."""
        encoded = body.encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Cache-Control", "no-store")
        self.send_header("Content-Length", str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)
|
||||
|
||||
|
||||
def main() -> None:
    """Serve the Astrape web UI until interrupted."""
    address = (HOST, PORT)
    server = ThreadingHTTPServer(address, AstrapeWebHandler)
    print(f"Astrape web UI listening on http://{HOST}:{PORT}")
    server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3
"""Debug script to find duplicate target times in forecast data."""

from gibil.classes.env_loader import EnvLoader
from gibil.classes.weather.store import WeatherStore
from collections import defaultdict

# Load environment configuration before touching the store.
EnvLoader().load()

store = WeatherStore.from_env()
dataset = store.load_display_dataset()

# Group by (target_at, horizon_hours) to find duplicates
by_key = defaultdict(list)
for point in dataset.forecast_points:
    key = (point.target_at, point.horizon_hours)
    by_key[key].append(point)

# Find duplicates
duplicates = {k: v for k, v in by_key.items() if len(v) > 1}

print(f"\nTotal forecast points: {len(dataset.forecast_points)}")
print(f"Unique (target_at, horizon) pairs: {len(by_key)}")
print(f"Duplicate (target_at, horizon) pairs: {len(duplicates)}")

if duplicates:
    # Show a small sample so the offending issued_at/source values are visible.
    print("\nFirst 3 duplicates:")
    for (target_at, horizon), points in list(duplicates.items())[:3]:
        print(f"\n  target_at={target_at}, horizon={horizon}h ({len(points)} points):")
        for i, p in enumerate(points):
            print(f"    [{i}] issued_at={p.issued_at}, temp={p.temperature_c}, source={p.source}")
|
||||
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Debug baseline energy forecast curves."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from datetime import timezone
|
||||
|
||||
from gibil.classes.oracle.builder import EnergyForecastBuilder
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
from gibil.classes.models import PowerForecastPoint
|
||||
|
||||
|
||||
def main() -> None:
    """Print baseline solar/load/net forecast curves as a fixed-width table."""
    EnvLoader().load()
    args = parse_args()
    solar_run, load_run, net_run = EnergyForecastBuilder.from_env().build()

    print(
        f"issued_at={net_run.issued_at.astimezone(timezone.utc).isoformat(timespec='seconds')}"
    )
    print(
        f"solar_model={solar_run.model_version} "
        f"load_model={load_run.model_version} points={len(net_run.points)}"
    )
    # Column header; widths match the per-row format specs below.
    print(
        "target_at                 solar_p10 solar_p50 solar_p90 "
        "load_p10 load_p50 load_p90  net_p10 net_p50 net_p90"
    )
    solar_by_target = _by_target(solar_run.points)
    load_by_target = _by_target(load_run.points)
    for point in net_run.points[: args.limit]:
        # NOTE(review): assumes every net-run target_at also appears in the
        # solar and load runs; a missing key raises KeyError here — confirm
        # the builder guarantees aligned curves.
        solar_point = solar_by_target[point.target_at]
        load_point = load_by_target[point.target_at]
        print(
            f"{point.target_at.astimezone(timezone.utc).isoformat(timespec='minutes'):25} "
            f"{solar_point.p10_power_w:9.0f} "
            f"{solar_point.p50_power_w:9.0f} "
            f"{solar_point.p90_power_w:9.0f} "
            f"{load_point.p10_power_w:8.0f} "
            f"{load_point.p50_power_w:8.0f} "
            f"{load_point.p90_power_w:8.0f} "
            f"{point.p10_net_power_w:7.0f} "
            f"{point.p50_net_power_w:7.0f} "
            f"{point.p90_net_power_w:7.0f}"
        )
|
||||
|
||||
|
||||
def _by_target(points: list[PowerForecastPoint]) -> dict[object, PowerForecastPoint]:
|
||||
return {point.target_at: point for point in points}
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse the --limit option bounding how many points get printed."""
    parser = argparse.ArgumentParser(
        description="Print baseline solar/load/net forecast curves."
    )
    limit_help = "Number of forecast points to show. Defaults to 24."
    parser.add_argument("--limit", type=int, default=24, help=limit_help)
    return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,5 @@
|
||||
from gibil.scripts.oracle_evaluator import main
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,198 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Debug script to inspect stored Sigenergy plant snapshots."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from datetime import timezone
|
||||
import json
|
||||
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
from gibil.classes.sigen.store import SigenStore
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point: show raw snapshots or one rollup view, per CLI args."""
    EnvLoader().load()
    args = parse_args()
    store = SigenStore.from_env()

    if args.view == "raw":
        print_raw_snapshots(load_raw_snapshots(store, args.limit))
    else:
        bucket_rows = load_rollup(store, args.view, args.limit)
        print_rollup(bucket_rows, args.view)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse CLI options selecting which snapshot view to inspect."""
    parser = argparse.ArgumentParser(
        description="Inspect stored Sigenergy plant snapshots."
    )
    view_help = "View to inspect. Defaults to raw."
    limit_help = "Number of most recent rows/buckets to show. Defaults to 10."
    parser.add_argument(
        "--view",
        choices=("raw", "1m", "15m", "1h"),
        default="raw",
        help=view_help,
    )
    parser.add_argument("--limit", type=int, default=10, help=limit_help)
    return parser.parse_args()
|
||||
|
||||
|
||||
def load_raw_snapshots(store: SigenStore, limit: int) -> list[tuple]:
    """Fetch the *limit* most recent raw snapshot rows, newest first.

    NOTE(review): reaches into the store's private ``_connection()`` —
    tolerable for a debug script, but keep the column list in sync with
    SigenStore's schema.
    """
    with store._connection() as connection:
        with connection.cursor() as cursor:
            cursor.execute(
                """
                SELECT
                    observed_at,
                    received_at,
                    solar_power_w,
                    load_power_w,
                    battery_soc_pct,
                    battery_power_w,
                    grid_import_w,
                    grid_export_w,
                    plant_active_power_w,
                    accumulated_pv_energy_kwh,
                    daily_consumed_energy_kwh,
                    raw_values
                FROM sigen_plant_snapshots
                ORDER BY observed_at DESC
                LIMIT %s
                """,
                (limit,),
            )
            return cursor.fetchall()
|
||||
|
||||
|
||||
def load_rollup(store: SigenStore, view: str, limit: int) -> list[tuple]:
    """Fetch the *limit* most recent buckets from one rollup view.

    The f-string interpolation of the table name is safe: *view* is mapped
    through a fixed dict (a KeyError rejects anything else), so no
    user-controlled SQL reaches the query.
    """
    view_name = {
        "1m": "sigen_plant_snapshots_1m",
        "15m": "sigen_plant_snapshots_15m",
        "1h": "sigen_plant_snapshots_1h",
    }[view]

    with store._connection() as connection:
        with connection.cursor() as cursor:
            cursor.execute(
                f"""
                SELECT
                    bucket,
                    sample_count,
                    avg_solar_power_w,
                    max_solar_power_w,
                    avg_load_power_w,
                    max_load_power_w,
                    avg_grid_import_w,
                    max_grid_import_w,
                    avg_grid_export_w,
                    max_grid_export_w,
                    avg_battery_soc_pct
                FROM {view_name}
                ORDER BY bucket DESC
                LIMIT %s
                """,
                (limit,),
            )
            return cursor.fetchall()
|
||||
|
||||
|
||||
def print_raw_snapshots(rows: list[tuple]) -> None:
    """Pretty-print raw snapshot rows, one line per observation.

    NOTE(review): assumes ``raw_values`` arrives as a decoded mapping
    (e.g. JSONB) — iterating it and calling ``.items()`` would fail on a
    plain JSON string; confirm the store's column type.
    """
    print(f"raw_snapshots={len(rows)}")
    for row in rows:
        (
            observed_at,
            received_at,
            solar_power_w,
            load_power_w,
            battery_soc_pct,
            battery_power_w,
            grid_import_w,
            grid_export_w,
            plant_active_power_w,
            accumulated_pv_energy_kwh,
            daily_consumed_energy_kwh,
            raw_values,
        ) = row
        print(
            f"{_fmt_time(observed_at)} "
            f"solar={_fmt_w(solar_power_w)} "
            f"load={_fmt_w(load_power_w)} "
            f"soc={_fmt_pct(battery_soc_pct)} "
            f"battery={_fmt_w(battery_power_w)} "
            f"import={_fmt_w(grid_import_w)} "
            f"export={_fmt_w(grid_export_w)} "
            f"plant={_fmt_w(plant_active_power_w)} "
            f"pv_total={_fmt_kwh(accumulated_pv_energy_kwh)} "
            f"load_today={_fmt_kwh(daily_consumed_energy_kwh)} "
            # Lag between the meter's observation time and our receipt time.
            f"lag_s={(received_at - observed_at).total_seconds():.1f}"
        )
        # Surface any per-register read errors captured alongside the values.
        if raw_values and any(key.endswith("_error") for key in raw_values):
            errors = {
                key: value
                for key, value in raw_values.items()
                if key.endswith("_error")
            }
            print(f"  errors={json.dumps(errors, default=str)}")
|
||||
|
||||
|
||||
def print_rollup(rows: list[tuple]) -> None:
    """Pretty-print rollup buckets for *view*, one line per bucket."""
    print(f"{view}_buckets={len(rows)}")
    for row in rows:
        (
            bucket,
            sample_count,
            avg_solar_power_w,
            max_solar_power_w,
            avg_load_power_w,
            max_load_power_w,
            avg_grid_import_w,
            max_grid_import_w,
            avg_grid_export_w,
            max_grid_export_w,
            avg_battery_soc_pct,
        ) = row
        print(
            f"{_fmt_time(bucket)} samples={sample_count:4} "
            f"solar_avg={_fmt_w(avg_solar_power_w)} "
            f"solar_max={_fmt_w(max_solar_power_w)} "
            f"load_avg={_fmt_w(avg_load_power_w)} "
            f"load_max={_fmt_w(max_load_power_w)} "
            f"import_avg={_fmt_w(avg_grid_import_w)} "
            f"import_max={_fmt_w(max_grid_import_w)} "
            f"export_avg={_fmt_w(avg_grid_export_w)} "
            f"export_max={_fmt_w(max_grid_export_w)} "
            f"soc_avg={_fmt_pct(avg_battery_soc_pct)}"
        )
|
||||
|
||||
|
||||
def _fmt_time(value) -> str:
|
||||
return value.astimezone(timezone.utc).isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def _fmt_w(value) -> str:
|
||||
if value is None:
|
||||
return "None"
|
||||
return f"{value:.0f}W"
|
||||
|
||||
|
||||
def _fmt_pct(value) -> str:
|
||||
if value is None:
|
||||
return "None"
|
||||
return f"{value:.1f}%"
|
||||
|
||||
|
||||
def _fmt_kwh(value) -> str:
|
||||
if value is None:
|
||||
return "None"
|
||||
return f"{value:.2f}kWh"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,31 @@
|
||||
#!/usr/bin/env python3
"""Debug script to trace resolved truth data."""

import json

from gibil.classes.env_loader import EnvLoader
from gibil.classes.weather.store import WeatherStore
from gibil.classes.weather.display import WeatherDisplay
# NOTE(review): datetime/timezone appear unused in this script — confirm
# before removing, in case a caller relies on the import side effect.
from datetime import datetime, timezone

# Load environment configuration before touching the store.
EnvLoader().load()

store = WeatherStore.from_env()
dataset = store.load_display_dataset()

# Compare what the store returned with what the display layer serializes,
# so discrepancies between stored truth and the API payload are visible.
print("\n=== DEBUG OUTPUT ===")
print(f"Forecast points: {len(dataset.forecast_points)}")
print(f"Resolved truth points: {len(dataset.resolved_truth)}")

print("\nResolved truth details:")
for i, point in enumerate(dataset.resolved_truth):
    print(f"  [{i}] resolved_at={point.resolved_at}, temp={point.temperature_c}, radiation={point.shortwave_radiation_w_m2}")

print("\nAPI payload:")
display = WeatherDisplay()
payload = display.data_payload(dataset)
data = json.loads(payload)
print(f"  Resolved truth in payload: {len(data['resolved_truth'])}")
for i, point in enumerate(data['resolved_truth']):
    print(f"    [{i}] resolved_at={point['resolved_at']}, temp={point['temperature_c']}")

print("\n=== END DEBUG ===\n")
|
||||
@@ -0,0 +1,383 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Explore a Sigenergy plant or inverter over Modbus TCP."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from dataclasses import asdict
|
||||
from os import environ
|
||||
|
||||
from gibil.classes.sigen.builder import SigenPlantClient
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
from gibil.classes.sigen.modbus import (
|
||||
ModbusReadError,
|
||||
ModbusReadResult,
|
||||
RegisterKind,
|
||||
SigenModbusClient,
|
||||
)
|
||||
from gibil.classes.sigen.registers import (
|
||||
DEFAULT_INVERTER_REGISTER_NAMES,
|
||||
DEFAULT_PLANT_REGISTER_NAMES,
|
||||
INVERTER_REGISTERS,
|
||||
PLANT_PARAMETER_REGISTERS,
|
||||
PLANT_REGISTERS,
|
||||
SigenRegister,
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_KINDS: tuple[RegisterKind, ...] = ("holding", "input")
|
||||
ALL_KINDS: tuple[RegisterKind, ...] = ("holding", "input", "coil", "discrete")
|
||||
DEFAULT_UNIT_CANDIDATES = (0, 1, 2, 3, 247, 255)
|
||||
|
||||
|
||||
def main() -> None:
    """Dispatch the chosen subcommand of the Modbus explorer CLI.

    units/catalog/snapshot run without a long-lived client; the remaining
    subcommands (probe/plant/inverter/read/scan) share one connected
    SigenModbusClient.
    """
    EnvLoader().load()
    args = parse_args()

    # Subcommands that manage their own connections (or need none).
    if args.command == "units":
        results = probe_units(args)
        print_results(results, errors=True)
        return
    if args.command == "catalog":
        if args.group in {"plant", "all"}:
            print_catalog("Plant Sensors", PLANT_REGISTERS)
        if args.group in {"params", "all"}:
            print_catalog("Plant Parameters", PLANT_PARAMETER_REGISTERS)
        return
    if args.command == "snapshot":
        snapshot = SigenPlantClient.from_env().fetch_snapshot()
        print(json.dumps(asdict(snapshot), indent=2, default=str))
        return

    # All remaining subcommands operate over a single client connection.
    with SigenModbusClient(
        host=args.host,
        port=args.port,
        unit_id=args.unit_id,
        timeout=args.timeout,
        retries=args.retries,
        trace=args.trace,
    ) as client:
        if args.command == "probe":
            # Reaching this point means the connection itself succeeded.
            print(
                f"Connected to {args.host}:{args.port} "
                f"with unit id {args.unit_id}"
            )
            return
        if args.command == "plant":
            print_known_registers(client, args.register, PLANT_REGISTERS)
            return
        if args.command == "inverter":
            print_known_registers(client, args.register, INVERTER_REGISTERS)
            return
        if args.command == "read":
            # For "read", args.kind is a single string (positional arg).
            try:
                result = client.read(args.kind, args.address, args.count)
                print_results([result], errors=True)
            except Exception as exc:
                print(
                    f"{args.kind:8} {args.address:5} "
                    f"+{args.count:<3} ERROR {exc}"
                )
            return

        # Fallthrough: the "scan" subcommand, where args.kind is a list.
        results: list[ModbusReadResult | ModbusReadError] = []
        for kind in args.kind:
            results.extend(
                client.scan(
                    kind=kind,
                    start=args.start,
                    count=args.count,
                    chunk_size=args.chunk_size,
                )
            )

        if args.json:
            print(json.dumps([asdict(result) for result in results], indent=2))
        else:
            print_results(results, errors=args.errors)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Build the explorer CLI: global connection flags plus subcommands.

    After parsing, per-subcommand defaults for repeatable options
    (``--kind``, ``--candidate``, ``--register``) are filled in manually,
    because argparse's ``action="append"`` would otherwise append to a
    non-None default list.
    """
    parser = argparse.ArgumentParser(
        description="Minimal Modbus TCP explorer for a Sigenergy plant/inverter."
    )
    # Global connection options, all overridable via SIGEN_MODBUS_* env vars.
    parser.add_argument(
        "--host",
        default=environ.get("SIGEN_MODBUS_HOST"),
        # Only required when the environment does not already provide it.
        required="SIGEN_MODBUS_HOST" not in environ,
        help="Modbus TCP host or IP. Can also be set as SIGEN_MODBUS_HOST.",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=int(environ.get("SIGEN_MODBUS_PORT", "502")),
        help="Modbus TCP port. Defaults to 502.",
    )
    parser.add_argument(
        "--unit-id",
        type=int,
        default=int(environ.get("SIGEN_MODBUS_UNIT_ID", "1")),
        help="Modbus unit/slave id. Defaults to 1.",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=float(environ.get("SIGEN_MODBUS_TIMEOUT", "5")),
        help="Socket timeout in seconds. Defaults to 5.",
    )
    parser.add_argument(
        "--retries",
        type=int,
        default=int(environ.get("SIGEN_MODBUS_RETRIES", "3")),
        help="Modbus request retries. Defaults to 3.",
    )
    parser.add_argument(
        "--trace",
        action="store_true",
        help="Print Modbus TCP packet bytes to stderr.",
    )

    subparsers = parser.add_subparsers(dest="command", required=True)

    subparsers.add_parser("probe", help="Open a connection and report success.")
    subparsers.add_parser(
        "snapshot",
        help="Read core plant metrics and print the builder snapshot as JSON.",
    )

    catalog = subparsers.add_parser(
        "catalog",
        help="List known Sigenergy plant sensors and settable parameters.",
    )
    catalog.add_argument(
        "group",
        choices=("plant", "params", "all"),
        nargs="?",
        default="all",
        help="Catalog group to list. Defaults to all.",
    )

    units = subparsers.add_parser(
        "units",
        help="Try small reads against likely unit ids.",
    )
    units.add_argument(
        "--candidate",
        action="append",
        type=int,
        default=None,
        help=(
            "Unit id candidate to try. Repeat for multiple ids. "
            "Defaults to 0, 1, 2, 3, 247, and 255."
        ),
    )
    units.add_argument(
        "--kind",
        action="append",
        choices=ALL_KINDS,
        default=None,
        help=(
            "Register table to test. Repeat for multiple kinds. "
            "Defaults to holding and input."
        ),
    )
    units.add_argument(
        "--address",
        type=int,
        default=30000,
        help="Address to test. Defaults to 30000.",
    )
    units.add_argument(
        "--count",
        type=int,
        default=1,
        help="Number of values to request. Defaults to 1.",
    )

    plant = subparsers.add_parser(
        "plant",
        help="Read a small set of known Sigenergy plant registers.",
    )
    plant.add_argument(
        "--register",
        action="append",
        choices=sorted(PLANT_REGISTERS),
        default=None,
        help="Known plant register to read. Repeat for multiple registers.",
    )

    inverter = subparsers.add_parser(
        "inverter",
        help="Read a small set of known Sigenergy inverter registers.",
    )
    inverter.add_argument(
        "--register",
        action="append",
        choices=sorted(INVERTER_REGISTERS),
        default=None,
        help="Known inverter register to read. Repeat for multiple registers.",
    )

    read = subparsers.add_parser(
        "read",
        help="Read one raw Modbus register range.",
    )
    read.add_argument(
        "kind",
        choices=ALL_KINDS,
        help="Register table to read.",
    )
    read.add_argument(
        "address",
        type=int,
        help="Modbus address to read.",
    )
    read.add_argument(
        "count",
        type=int,
        nargs="?",
        default=1,
        help="Number of values to read. Defaults to 1.",
    )

    scan = subparsers.add_parser("scan", help="Scan register ranges in chunks.")
    scan.add_argument(
        "--kind",
        action="append",
        choices=ALL_KINDS,
        default=None,
        help=(
            "Register table to scan. Repeat for multiple kinds. "
            "Defaults to holding and input."
        ),
    )
    scan.add_argument(
        "--start",
        type=int,
        default=0,
        help="Starting zero-based Modbus address. Defaults to 0.",
    )
    scan.add_argument(
        "--count",
        type=int,
        default=100,
        help="Number of addresses to scan. Defaults to 100.",
    )
    scan.add_argument(
        "--chunk-size",
        type=int,
        default=10,
        help="Addresses per Modbus request. Defaults to 10.",
    )
    scan.add_argument(
        "--errors",
        action="store_true",
        help="Show failed chunks as well as successful reads.",
    )
    scan.add_argument(
        "--json",
        action="store_true",
        help="Print raw result objects as JSON.",
    )

    args = parser.parse_args()
    # Fill per-subcommand defaults for the append-style options.
    if args.command == "scan" and args.kind is None:
        args.kind = list(DEFAULT_KINDS)
    if args.command == "units":
        if args.kind is None:
            args.kind = list(DEFAULT_KINDS)
        if args.candidate is None:
            args.candidate = list(DEFAULT_UNIT_CANDIDATES)
    if args.command == "plant" and args.register is None:
        args.register = list(DEFAULT_PLANT_REGISTER_NAMES)
    if args.command == "inverter" and args.register is None:
        args.register = list(DEFAULT_INVERTER_REGISTER_NAMES)

    return args
|
||||
|
||||
|
||||
def probe_units(args: argparse.Namespace) -> list[ModbusReadResult | ModbusReadError]:
    """Read the same register from every candidate Modbus unit id.

    Opens one connection per candidate unit and attempts the read for each
    requested register kind, collecting successes and failures alike so the
    caller can see which unit ids respond.
    """
    outcomes: list[ModbusReadResult | ModbusReadError] = []
    for unit in args.candidate:
        connection_settings = dict(
            host=args.host,
            port=args.port,
            unit_id=unit,
            timeout=args.timeout,
            retries=args.retries,
            trace=args.trace,
        )
        with SigenModbusClient(**connection_settings) as client:
            for table_kind in args.kind:
                try:
                    outcomes.append(client.read(table_kind, args.address, args.count))
                except Exception as exc:
                    # Probing is best-effort: record the failure instead of raising.
                    failure = ModbusReadError(
                        kind=table_kind,
                        address=args.address,
                        count=args.count,
                        error=f"unit {unit}: {exc}",
                    )
                    outcomes.append(failure)
    return outcomes
|
||||
|
||||
|
||||
def print_known_registers(
    client: SigenModbusClient,
    register_names: list[str],
    registers: dict[str, SigenRegister],
) -> None:
    """Read and pretty-print each named register from the catalog.

    Failures are printed inline rather than raised, so one bad register does
    not abort the listing.
    """
    for name in register_names:
        register = registers[name]
        location = f"{register.kind} {register.address} +{register.count}"
        try:
            raw = client.read(register.kind, register.address, register.count)
            decoded = register.decode(raw.values)
            suffix = f" {register.unit}" if register.unit else ""
            raw_text = " ".join(str(item) for item in raw.values)
            print(f"{register.name:32} {decoded}{suffix:4} ({location}: {raw_text})")
        except Exception as exc:
            print(f"{register.name:32} ERROR ({location}: {exc})")
|
||||
|
||||
|
||||
def print_catalog(title: str, registers: dict[str, SigenRegister]) -> None:
    """Print an underlined table of every register definition in *registers*."""
    print(title)
    print("-" * len(title))
    for register in registers.values():
        row = (
            f"{register.name:48} {register.kind:7} "
            f"{register.address:5} +{register.count:<2} "
            f"{register.data_type:6} gain={register.gain:<7g} "
            f"{register.unit or '':5} {register.description or ''}"
        )
        print(row)
    # Trailing blank line separates consecutive catalogs.
    print()
|
||||
|
||||
|
||||
def print_results(
    results: list[ModbusReadResult | ModbusReadError],
    errors: bool,
) -> None:
    """Print one line per read result; failed chunks only when *errors* is set."""
    for item in results:
        prefix = f"{item.kind:8} {item.address:5} +{item.count:<3}"
        if isinstance(item, ModbusReadError):
            if errors:
                print(f"{prefix} ERROR {item.error}")
            continue
        rendered = " ".join(str(value) for value in item.values)
        print(f"{prefix} {rendered}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    main()  # Script entrypoint; not executed when imported as a module.
|
||||
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from datetime import timedelta
|
||||
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
from gibil.classes.oracle.store import OracleStore
|
||||
|
||||
|
||||
def main() -> None:
    """Evaluate stored oracle predictions and/or print their quality summary.

    Mode flags come from :func:`parse_args`; with neither flag supplied both
    the evaluation pass and the summary are run.
    """
    EnvLoader().load()
    args = parse_args()
    store = OracleStore.from_env()

    if args.evaluate:
        count = store.evaluate_due_forecasts(
            actual_window=timedelta(minutes=args.actual_window_minutes),
            lookback=timedelta(hours=args.lookback_hours),
            limit=args.limit,
        )
        print(f"evaluated_oracle_forecasts={count}")

    if args.summary:
        rows = store.load_evaluation_summary(
            lookback=timedelta(hours=args.lookback_hours)
        )
        print(f"oracle_evaluation_summary_rows={len(rows)}")
        for row in rows:
            # One space-separated key=value line per kind/model/horizon bucket.
            fields = [
                f"kind={row['kind']}",
                f"model={row['model_version']}",
                f"horizon={row.get('horizon_label') or _format_horizon(row)}",
                f"n={row['evaluated_count']}",
                f"bias={_format_w(row['mean_error_w'])}",
                f"mae={_format_w(row['mean_absolute_error_w'])}",
                f"median_ae={_format_w(row['median_absolute_error_w'])}",
                f"mape={_format_pct(row['mean_absolute_pct_error'])}",
                f"coverage={_format_pct(row['interval_coverage'])}",
            ]
            print(" ".join(fields))
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the oracle evaluator.

    When neither ``--evaluate`` nor ``--summary`` is given, both modes are
    enabled so a bare invocation does something useful.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate stored Astrape oracle predictions against Sigen actuals."
    )
    flag_specs = [
        (
            "--evaluate",
            dict(
                action="store_true",
                help="Evaluate stored predictions whose target time has passed.",
            ),
        ),
        (
            "--summary",
            dict(
                action="store_true",
                help="Print evaluation quality by kind/model/horizon bucket.",
            ),
        ),
        (
            "--actual-window-minutes",
            dict(
                type=float,
                default=5,
                help="Minutes after each target timestamp to average as realized actuals.",
            ),
        ),
        (
            "--lookback-hours",
            dict(
                type=float,
                default=168,
                help="Only evaluate/summarize predictions with target times this recent.",
            ),
        ),
        (
            "--limit",
            dict(
                type=int,
                default=1000,
                help="Maximum unevaluated predictions to process.",
            ),
        ),
    ]
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)
    parsed = parser.parse_args()
    if not (parsed.evaluate or parsed.summary):
        # No explicit mode requested: do everything.
        parsed.evaluate = True
        parsed.summary = True
    return parsed
|
||||
|
||||
|
||||
def _format_w(value: object) -> str:
|
||||
if value is None:
|
||||
return "n/a"
|
||||
return f"{float(value):.0f}W"
|
||||
|
||||
|
||||
def _format_horizon(row: dict[str, object]) -> str:
|
||||
return f"{row['min_horizon_minutes']}-{row['max_horizon_minutes']}m"
|
||||
|
||||
|
||||
def _format_pct(value: object) -> str:
|
||||
if value is None:
|
||||
return "n/a"
|
||||
return f"{float(value) * 100:.1f}%"
|
||||
|
||||
|
||||
if __name__ == "__main__":
    main()  # Script entrypoint; not executed when imported as a module.
|
||||
@@ -0,0 +1 @@
|
||||
"""TCN training and inspection scripts."""
|
||||
@@ -0,0 +1,254 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from random import Random
|
||||
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
from gibil.classes.predictors.usage_hybrid_tcn import (
|
||||
UsageHybridTCNConfig,
|
||||
build_usage_hybrid_tcn,
|
||||
pinball_loss,
|
||||
)
|
||||
from gibil.classes.predictors.usage_sequence_dataset import (
|
||||
UsageSequenceExample,
|
||||
UsageSequenceDatasetBuilder,
|
||||
UsageSequenceDatasetConfig,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
    """Build usage-sequence training windows and optionally train the TCN.

    Steps: load env config, parse CLI flags, build dataset examples, print
    dataset-shape diagnostics, then (unless ``--dry-run``) hand off to
    ``train_model`` to fit and save the model artifact.
    """
    EnvLoader().load()
    args = parse_args()
    config = UsageSequenceDatasetConfig.from_env()
    builder = UsageSequenceDatasetBuilder(config=config)
    examples = builder.build(limit=args.limit)

    # Dataset-shape diagnostics, printed even on a dry run.
    print(f"usage_sequence_examples={len(examples)}")
    print(
        "minimum_history_hours="
        f"{builder.max_past_hours + config.future_hours}"
    )
    print(f"past_features={len(builder.past_feature_names)}")
    for scale in config.past_scales:
        print(
            f"past_scale={scale.name} "
            f"hours={scale.hours} "
            f"step_seconds={scale.step_seconds} "
            f"steps={builder.past_steps(scale)}"
        )
    print(f"future_steps={builder.future_steps}")
    print(f"future_features={len(builder.future_feature_names)}")

    if examples:
        # Preview the first/last window so obvious data problems surface early.
        first = examples[0]
        last = examples[-1]
        print(f"first_issued_at={first.issued_at.isoformat()}")
        print(f"last_issued_at={last.issued_at.isoformat()}")
        for name, rows in first.past_by_scale.items():
            print(f"first_past_{name}_shape={len(rows)}x{len(rows[0])}")
            token_count = sum(
                len(tokens)
                for tokens in first.past_tokens_by_scale[name]
            )
            print(f"first_past_{name}_tokens={token_count}")
        print(
            "first_future_feature_shape="
            f"{len(first.future_features)}x{len(first.future_features[0])}"
        )
        print(
            "first_future_tokens="
            f"{sum(len(tokens) for tokens in first.future_tokens)}"
        )
        print(f"first_targets={len(first.targets)}")
        print(
            "first_target_preview="
            + ",".join(f"{value:.0f}" for value in first.targets[:8])
        )

    if args.dry_run:
        # Shape inspection only; skip training entirely.
        return

    if not examples:
        raise SystemExit("No usage sequence examples available for training yet.")

    train_model(
        examples=examples,
        builder=builder,
        epochs=args.epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        artifact_path=args.artifact_path,
        seed=args.seed,
    )
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse CLI options for dataset building and TCN training."""
    parser = argparse.ArgumentParser(
        description="Build training windows for the sequence usage oracle."
    )
    option_specs = [
        (
            "--dry-run",
            dict(
                action="store_true",
                help="Only build examples and print dataset shape.",
            ),
        ),
        (
            "--limit",
            dict(
                type=int,
                default=None,
                help="Optional maximum number of examples to build.",
            ),
        ),
        ("--epochs", dict(type=int, default=20, help="Training epochs.")),
        ("--batch-size", dict(type=int, default=32, help="Training batch size.")),
        (
            "--learning-rate",
            dict(type=float, default=0.001, help="Adam learning rate."),
        ),
        (
            "--artifact-path",
            dict(
                type=Path,
                default=Path("models/usage_sequence_tcn_v1.pt"),
                help="Where to save the trained TCN artifact.",
            ),
        ),
        ("--seed", dict(type=int, default=7, help="Deterministic shuffle seed.")),
    ]
    for option, settings in option_specs:
        parser.add_argument(option, **settings)
    return parser.parse_args()
|
||||
|
||||
|
||||
def train_model(
    examples: list[UsageSequenceExample],
    builder: UsageSequenceDatasetBuilder,
    epochs: int,
    batch_size: int,
    learning_rate: float,
    artifact_path: Path,
    seed: int,
) -> None:
    """Train the hybrid TCN on pre-built windows and save a torch artifact.

    Performs a deterministic shuffle/split, trains with Adam against the
    pinball (quantile) loss, prints per-epoch losses, and writes a checkpoint
    dict (config, feature names, state_dict) to *artifact_path*.
    Raises ``SystemExit`` if PyTorch is not installed.
    """
    try:
        # Imported lazily so the module can be used for dataset inspection
        # without PyTorch installed.
        import torch
    except ImportError as error:
        raise SystemExit(
            "PyTorch is required for training. Install it with "
            "`python3 -m pip install -r requirements.txt`."
        ) from error

    # NOTE(review): mkldnn/nnpack are disabled — presumably to work around
    # backend instability on the deployment hardware; confirm before removing.
    torch.backends.mkldnn.enabled = False
    if hasattr(torch.backends, "nnpack"):
        torch.backends.nnpack.enabled = False
    scale_names = tuple(scale.name for scale in builder.config.past_scales)
    model_config = UsageHybridTCNConfig(
        past_feature_count=len(builder.past_feature_names),
        future_feature_count=len(builder.future_feature_names),
        future_steps=builder.future_steps,
        scale_names=scale_names,
    )
    model = build_usage_hybrid_tcn(model_config)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Deterministic shuffle; hold out ~20% for validation once we have >= 5
    # examples, otherwise train on everything with no validation pass.
    shuffled = examples[:]
    Random(seed).shuffle(shuffled)
    validation_count = max(1, len(shuffled) // 5) if len(shuffled) >= 5 else 0
    validation_examples = shuffled[:validation_count]
    training_examples = shuffled[validation_count:] or shuffled

    for epoch in range(1, epochs + 1):
        model.train()
        training_losses = []
        for batch in batches(training_examples, batch_size):
            past_by_scale, future_features, targets = examples_to_tensors(
                batch,
                scale_names,
                torch,
            )
            prediction = model(past_by_scale, future_features)
            loss = pinball_loss(prediction, targets, model_config.quantiles)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_losses.append(float(loss.detach()))

        # Validation pass (no gradients) when a hold-out split exists.
        validation_loss = None
        if validation_examples:
            model.eval()
            with torch.no_grad():
                validation_losses = []
                for batch in batches(validation_examples, batch_size):
                    past_by_scale, future_features, targets = examples_to_tensors(
                        batch,
                        scale_names,
                        torch,
                    )
                    prediction = model(past_by_scale, future_features)
                    loss = pinball_loss(prediction, targets, model_config.quantiles)
                    validation_losses.append(float(loss.detach()))
                validation_loss = sum(validation_losses) / len(validation_losses)

        train_loss = sum(training_losses) / len(training_losses)
        message = f"epoch={epoch} train_pinball_loss={train_loss:.4f}"
        if validation_loss is not None:
            message += f" validation_pinball_loss={validation_loss:.4f}"
        print(message)

    # Persist everything needed to rebuild and run the model at inference time.
    artifact_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(
        {
            "model_version": "sequence_usage_tcn_v1",
            "model_config": model_config.__dict__,
            "past_feature_names": builder.past_feature_names,
            "future_feature_names": builder.future_feature_names,
            "state_dict": model.state_dict(),
        },
        artifact_path,
    )
    print(f"saved_artifact={artifact_path}")
|
||||
|
||||
|
||||
def examples_to_tensors(
    examples: list[UsageSequenceExample],
    scale_names: tuple[str, ...],
    torch,
):
    """Stack a batch of examples into float32 tensors for the model.

    Returns ``(past_by_scale, future_features, targets)`` where the first is
    a dict keyed by scale name. The ``torch`` module is passed in so this
    module stays importable without PyTorch installed.
    """
    def stack(rows):
        return torch.tensor(rows, dtype=torch.float32)

    past_by_scale = {}
    for name in scale_names:
        past_by_scale[name] = stack(
            [example.past_by_scale[name] for example in examples]
        )
    future_features = stack([example.future_features for example in examples])
    targets = stack([example.targets for example in examples])
    return past_by_scale, future_features, targets
|
||||
|
||||
|
||||
def batches(
    examples: list[UsageSequenceExample],
    batch_size: int,
):
    """Yield consecutive slices of *examples* with at most *batch_size* items."""
    total = len(examples)
    start = 0
    while start < total:
        yield examples[start : start + batch_size]
        start += batch_size
|
||||
|
||||
|
||||
if __name__ == "__main__":
    main()  # Script entrypoint; not executed when imported as a module.
|
||||
@@ -1,84 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from importlib import import_module, reload
|
||||
from os import environ
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from gibil.classes.env_loader import EnvLoader
|
||||
|
||||
# Load environment configuration at import time (EnvLoader semantics defined
# elsewhere) so the settings below see the populated environment.
EnvLoader().load()

# Bind address and port for the web UI; overridable via environment variables.
HOST = environ.get("ASTRAPE_WEB_HOST", "0.0.0.0")
PORT = int(environ.get("ASTRAPE_WEB_PORT", "8080"))
# Repository root, two directory levels up from this file.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
# Source files whose mtimes feed the UI version stamp served at
# /api/ui-version (lets clients detect when the UI code changed).
WATCHED_PATHS = [
    PROJECT_ROOT / "gibil" / "classes" / "webui.py",
    PROJECT_ROOT / "gibil" / "classes" / "weather_display.py",
    PROJECT_ROOT / "gibil" / "classes" / "weather_store.py",
]
|
||||
|
||||
|
||||
class AstrapeWebHandler(BaseHTTPRequestHandler):
    """HTTP handler for the Astrape web UI.

    Routes: ``/`` (rendered HTML page), ``/api/weather`` (JSON payload), and
    ``/api/ui-version`` (mtime-based version stamp). UI modules are
    re-imported on every request so edits take effect without a restart.
    """

    def do_GET(self) -> None:
        # Route the three known paths; anything else is a 404.
        if self.path == "/":
            self._send_html(self._webui().render_page())
            return

        if self.path == "/api/weather":
            self._send_json_text(self._webui().weather_payload())
            return

        if self.path == "/api/ui-version":
            self._send_json_text(json.dumps({"version": self._ui_version()}))
            return

        self.send_error(404)

    def log_message(self, format: str, *args: object) -> None:
        # Route the base class's access logging through plain print
        # (BaseHTTPRequestHandler writes to stderr by default).
        print(f"{self.address_string()} - {format % args}")

    def _webui(self):
        # Hot-reload: re-import and reload the UI modules on every request so
        # local code edits show up without restarting the daemon. This trades
        # per-request import cost for development convenience.
        weather_store_module = import_module("gibil.classes.weather_store")
        weather_display_module = import_module("gibil.classes.weather_display")
        webui_module = import_module("gibil.classes.webui")
        reload(weather_store_module)
        reload(weather_display_module)
        reload(webui_module)
        return webui_module.WebUI()

    def _ui_version(self) -> str:
        # Version stamp = dot-joined mtimes of the watched source files;
        # changes whenever any of them is edited (missing files are skipped).
        mtimes = [
            str(path.stat().st_mtime_ns)
            for path in WATCHED_PATHS
            if path.exists()
        ]
        return ".".join(mtimes)

    def _send_html(self, body: str) -> None:
        # 200 HTML response with explicit Content-Length.
        encoded = body.encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "text/html; charset=utf-8")
        self.send_header("Content-Length", str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)

    def _send_json_text(self, body: str) -> None:
        # 200 JSON response; no-store so clients always poll fresh data.
        encoded = body.encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json; charset=utf-8")
        self.send_header("Cache-Control", "no-store")
        self.send_header("Content-Length", str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)
|
||||
|
||||
|
||||
def main() -> None:
    """Start the threaded HTTP server and block serving requests forever."""
    bind_address = (HOST, PORT)
    server = ThreadingHTTPServer(bind_address, AstrapeWebHandler)
    print(f"Astrape web UI listening on http://{HOST}:{PORT}")
    server.serve_forever()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    main()  # Script entrypoint; not executed when imported as a module.
|
||||
Reference in New Issue
Block a user