"""Multi-scale temporal convolutional network (TCN) with a quantile head.

PyTorch is imported lazily inside the functions that need it, so this module
can be imported (for example to construct a config) without PyTorch installed.
"""

from __future__ import annotations

from dataclasses import dataclass

# Shared by every entry point that needs PyTorch so the guidance stays in sync.
_TORCH_REQUIRED_MESSAGE = (
    "PyTorch is required for TCN training. Install dependencies with "
    "`python3 -m pip install -r requirements.txt`."
)


@dataclass(frozen=True)
class UsageHybridTCNConfig:
    """Immutable hyper-parameters consumed by ``build_usage_hybrid_tcn``."""

    # Feature width of every per-scale past sequence (Conv1d input channels).
    past_feature_count: int
    # Feature width of each future step fed to the future encoder.
    future_feature_count: int
    # Not read by the model builder below; the number of forecast steps
    # follows the time dimension of the future tensor passed to ``forward``.
    future_steps: int
    # One temporal branch is built per name; these are also the keys looked up
    # in the ``past_by_scale`` mapping at forward time.
    scale_names: tuple[str, ...]
    # Width of every hidden layer (convolutions and linear layers).
    hidden_channels: int = 64
    # Number of stacked temporal blocks per branch; dilation doubles per block.
    branch_layers: int = 4
    # Dropout probability used in the temporal blocks, context and head.
    dropout: float = 0.10
    # The head emits one output channel per quantile.
    quantiles: tuple[float, ...] = (0.10, 0.50, 0.90)


def _validate_config(config: UsageHybridTCNConfig) -> None:
    """Reject configurations that cannot produce a usable model.

    Raises:
        ValueError: If ``scale_names`` or ``quantiles`` is empty.
    """
    if not config.scale_names:
        raise ValueError("scale_names must contain at least one scale.")
    if not config.quantiles:
        raise ValueError("quantiles must contain at least one quantile.")


def build_usage_hybrid_tcn(config: UsageHybridTCNConfig):
    """Build the hybrid TCN described by ``config``.

    Args:
        config: Model hyper-parameters.

    Returns:
        A freshly initialised ``torch.nn.Module`` whose ``forward`` takes
        ``(past_by_scale, future_features)``.

    Raises:
        ValueError: If ``config`` has no scale names or no quantiles.
        RuntimeError: If PyTorch cannot be imported.
    """
    try:
        return _build_usage_hybrid_tcn(config)
    except ImportError as error:
        raise RuntimeError(_TORCH_REQUIRED_MESSAGE) from error


def _build_usage_hybrid_tcn(config: UsageHybridTCNConfig):
    """Define the network classes against ``config`` and return an instance.

    The classes are created inside this function so that importing this module
    does not require PyTorch; they read ``config`` through the closure.
    """
    # Validate before importing torch so bad configs fail fast and cheaply.
    _validate_config(config)

    import torch
    from torch import nn

    class CausalTrim(nn.Module):
        """Drop the last ``trim`` time steps of a (batch, channels, time) tensor."""

        def __init__(self, trim: int) -> None:
            super().__init__()
            self.trim = trim

        def forward(self, value):
            # ``value[:, :, :-0]`` would be empty, so a zero trim is a no-op.
            if self.trim <= 0:
                return value
            return value[:, :, : -self.trim]

    class TemporalBlock(nn.Module):
        """Two dilated causal convolutions with a residual connection."""

        def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            dilation: int,
            dropout: float,
        ) -> None:
            super().__init__()
            # Conv1d pads both ends symmetrically; padding by the full
            # receptive field and trimming the tail afterwards keeps the
            # output length equal to the input length while making each
            # output step depend only on current and earlier inputs.
            padding = (kernel_size - 1) * dilation
            self.net = nn.Sequential(
                nn.Conv1d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=padding,
                ),
                CausalTrim(padding),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Conv1d(
                    out_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=padding,
                ),
                CausalTrim(padding),
                nn.ReLU(),
                nn.Dropout(dropout),
            )
            # A 1x1 convolution matches channel counts for the skip path when
            # the block changes width.
            self.residual = (
                nn.Conv1d(in_channels, out_channels, kernel_size=1)
                if in_channels != out_channels
                else nn.Identity()
            )
            self.activation = nn.ReLU()

        def forward(self, value):
            return self.activation(self.net(value) + self.residual(value))

    class TemporalBranch(nn.Module):
        """Stack of temporal blocks that summarises one past sequence."""

        def __init__(self) -> None:
            super().__init__()
            layers = []
            channels = config.past_feature_count
            for layer_index in range(config.branch_layers):
                layers.append(
                    TemporalBlock(
                        in_channels=channels,
                        out_channels=config.hidden_channels,
                        kernel_size=5,
                        dilation=2**layer_index,
                        dropout=config.dropout,
                    )
                )
                channels = config.hidden_channels
            self.net = nn.Sequential(*layers)
            # Actual width of the branch output. With zero layers the branch
            # passes the raw features through, so this is not always
            # ``hidden_channels``.
            self.out_channels = channels

        def forward(self, value):
            # Dataset tensors are batch x time x features; Conv1d wants batch x features x time.
            encoded = self.net(value.transpose(1, 2))
            # Keep only the final time step as the sequence summary.
            return encoded[:, :, -1]

    class UsageHybridTCN(nn.Module):
        """Per-scale temporal branches fused with encoded future features."""

        def __init__(self) -> None:
            super().__init__()
            self.branches = nn.ModuleDict(
                {name: TemporalBranch() for name in config.scale_names}
            )
            # Size the fusion layer from what the branches really emit, one
            # output per entry in ``scale_names`` (mirrors ``forward``).
            branch_width = sum(
                self.branches[name].out_channels for name in config.scale_names
            )
            self.context = nn.Sequential(
                nn.Linear(branch_width, config.hidden_channels),
                nn.ReLU(),
                nn.Dropout(config.dropout),
            )
            self.future_encoder = nn.Sequential(
                nn.Linear(config.future_feature_count, config.hidden_channels),
                nn.ReLU(),
            )
            self.head = nn.Sequential(
                nn.Linear(config.hidden_channels * 2, config.hidden_channels),
                nn.ReLU(),
                nn.Dropout(config.dropout),
                nn.Linear(config.hidden_channels, len(config.quantiles)),
            )

        def forward(self, past_by_scale, future_features):
            """Return quantile predictions for every future step.

            Args:
                past_by_scale: Mapping from each name in ``scale_names`` to a
                    batch x time x features tensor.
                future_features: Tensor whose dimension 1 indexes future
                    steps and whose last dimension holds the future features
                    (assumed batch x steps x features).

            Returns:
                Tensor with one value per (batch, future step, quantile).
            """
            branch_outputs = [
                self.branches[name](past_by_scale[name])
                for name in config.scale_names
            ]
            context = self.context(torch.cat(branch_outputs, dim=1))
            future = self.future_encoder(future_features)
            # Broadcast the single past summary across every future step.
            repeated_context = context.unsqueeze(1).expand(-1, future.size(1), -1)
            return self.head(torch.cat([repeated_context, future], dim=2))

    return UsageHybridTCN()


def pinball_loss(prediction, target, quantiles: tuple[float, ...]):
    """Return the mean pinball (quantile) loss over all elements and quantiles.

    Args:
        prediction: Tensor indexed as ``prediction[:, :, i]`` for the i-th
            quantile, so it must be 3-D with the quantile axis last.
        target: Tensor of observed values; a trailing axis is appended before
            subtracting, so it is assumed to be shaped like ``prediction``
            without its last axis. Other shapes broadcast silently.
        quantiles: Quantile levels, in the order of the prediction channels.

    Returns:
        A scalar tensor.

    Raises:
        RuntimeError: If PyTorch cannot be imported.
        ValueError: If ``quantiles`` is empty.
    """
    try:
        import torch
    except ImportError as error:
        raise RuntimeError(_TORCH_REQUIRED_MESSAGE) from error

    if not quantiles:
        raise ValueError("quantiles must contain at least one quantile.")

    expanded_target = target.unsqueeze(-1)
    losses = []
    for index, quantile in enumerate(quantiles):
        residual = expanded_target - prediction[:, :, index : index + 1]
        # Under-prediction (residual > 0) is weighted by q and
        # over-prediction by (1 - q); the maximum selects the active branch.
        losses.append(
            torch.maximum(quantile * residual, (quantile - 1) * residual)
        )
    return torch.stack(losses, dim=-1).mean()