from __future__ import annotations import argparse from pathlib import Path from random import Random from gibil.classes.env_loader import EnvLoader from gibil.classes.predictors.usage_hybrid_tcn import ( UsageHybridTCNConfig, build_usage_hybrid_tcn, pinball_loss, ) from gibil.classes.predictors.usage_sequence_dataset import ( UsageSequenceExample, UsageSequenceDatasetBuilder, UsageSequenceDatasetConfig, ) def main() -> None: EnvLoader().load() args = parse_args() config = UsageSequenceDatasetConfig.from_env() builder = UsageSequenceDatasetBuilder(config=config) examples = builder.build(limit=args.limit) print(f"usage_sequence_examples={len(examples)}") print( "minimum_history_hours=" f"{builder.max_past_hours + config.future_hours}" ) print(f"past_features={len(builder.past_feature_names)}") for scale in config.past_scales: print( f"past_scale={scale.name} " f"hours={scale.hours} " f"step_seconds={scale.step_seconds} " f"steps={builder.past_steps(scale)}" ) print(f"future_steps={builder.future_steps}") print(f"future_features={len(builder.future_feature_names)}") if examples: first = examples[0] last = examples[-1] print(f"first_issued_at={first.issued_at.isoformat()}") print(f"last_issued_at={last.issued_at.isoformat()}") for name, rows in first.past_by_scale.items(): print(f"first_past_{name}_shape={len(rows)}x{len(rows[0])}") token_count = sum( len(tokens) for tokens in first.past_tokens_by_scale[name] ) print(f"first_past_{name}_tokens={token_count}") print( "first_future_feature_shape=" f"{len(first.future_features)}x{len(first.future_features[0])}" ) print( "first_future_tokens=" f"{sum(len(tokens) for tokens in first.future_tokens)}" ) print(f"first_targets={len(first.targets)}") print( "first_target_preview=" + ",".join(f"{value:.0f}" for value in first.targets[:8]) ) if args.dry_run: return if not examples: raise SystemExit("No usage sequence examples available for training yet.") train_model( examples=examples, builder=builder, epochs=args.epochs, batch_size=args.batch_size, learning_rate=args.learning_rate, artifact_path=args.artifact_path, seed=args.seed, ) def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Build training windows for the sequence usage oracle." ) parser.add_argument( "--dry-run", action="store_true", help="Only build examples and print dataset shape.", ) parser.add_argument( "--limit", type=int, default=None, help="Optional maximum number of examples to build.", ) parser.add_argument( "--epochs", type=int, default=20, help="Training epochs.", ) parser.add_argument( "--batch-size", type=int, default=32, help="Training batch size.", ) parser.add_argument( "--learning-rate", type=float, default=0.001, help="Adam learning rate.", ) parser.add_argument( "--artifact-path", type=Path, default=Path("models/usage_sequence_tcn_v1.pt"), help="Where to save the trained TCN artifact.", ) parser.add_argument( "--seed", type=int, default=7, help="Deterministic shuffle seed.", ) return parser.parse_args() def train_model( examples: list[UsageSequenceExample], builder: UsageSequenceDatasetBuilder, epochs: int, batch_size: int, learning_rate: float, artifact_path: Path, seed: int, ) -> None: try: import torch except ImportError as error: raise SystemExit( "PyTorch is required for training. Install it with " "`python3 -m pip install -r requirements.txt`." ) from error torch.backends.mkldnn.enabled = False if hasattr(torch.backends, "nnpack"): torch.backends.nnpack.enabled = False scale_names = tuple(scale.name for scale in builder.config.past_scales) model_config = UsageHybridTCNConfig( past_feature_count=len(builder.past_feature_names), future_feature_count=len(builder.future_feature_names), future_steps=builder.future_steps, scale_names=scale_names, ) model = build_usage_hybrid_tcn(model_config) optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) shuffled = examples[:] Random(seed).shuffle(shuffled) validation_count = max(1, len(shuffled) // 5) if len(shuffled) >= 5 else 0 validation_examples = shuffled[:validation_count] training_examples = shuffled[validation_count:] or shuffled for epoch in range(1, epochs + 1): model.train() training_losses = [] for batch in batches(training_examples, batch_size): past_by_scale, future_features, targets = examples_to_tensors( batch, scale_names, torch, ) prediction = model(past_by_scale, future_features) loss = pinball_loss(prediction, targets, model_config.quantiles) optimizer.zero_grad() loss.backward() optimizer.step() training_losses.append(float(loss.detach())) validation_loss = None if validation_examples: model.eval() with torch.no_grad(): validation_losses = [] for batch in batches(validation_examples, batch_size): past_by_scale, future_features, targets = examples_to_tensors( batch, scale_names, torch, ) prediction = model(past_by_scale, future_features) loss = pinball_loss(prediction, targets, model_config.quantiles) validation_losses.append(float(loss.detach())) validation_loss = sum(validation_losses) / len(validation_losses) train_loss = sum(training_losses) / len(training_losses) message = f"epoch={epoch} train_pinball_loss={train_loss:.4f}" if validation_loss is not None: message += f" validation_pinball_loss={validation_loss:.4f}" print(message) artifact_path.parent.mkdir(parents=True, exist_ok=True) torch.save( { "model_version": "sequence_usage_tcn_v1", "model_config": model_config.__dict__, "past_feature_names": builder.past_feature_names, "future_feature_names": builder.future_feature_names, "state_dict": model.state_dict(), }, artifact_path, ) print(f"saved_artifact={artifact_path}") def examples_to_tensors( examples: list[UsageSequenceExample], scale_names: tuple[str, ...], torch, ): past_by_scale = { name: torch.tensor( [example.past_by_scale[name] for example in examples], dtype=torch.float32, ) for name in scale_names } future_features = torch.tensor( [example.future_features for example in examples], dtype=torch.float32, ) targets = torch.tensor( [example.targets for example in examples], dtype=torch.float32, ) return past_by_scale, future_features, targets def batches( examples: list[UsageSequenceExample], batch_size: int, ): for start in range(0, len(examples), batch_size): yield examples[start : start + batch_size] if __name__ == "__main__": main()