| | from abc import ABC, abstractmethod |
| |
|
| | import torch |
| |
|
| |
|
| | class BaseScaler(ABC): |
| | """ |
| | Abstract base class for time series scalers. |
| | |
| | Defines the interface for scaling multivariate time series data with support |
| | for masked values and channel-wise scaling. |
| | """ |
| |
|
| | @abstractmethod |
| | def compute_statistics( |
| | self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None |
| | ) -> dict[str, torch.Tensor]: |
| | """ |
| | Compute scaling statistics from historical data. |
| | """ |
| | pass |
| |
|
| | @abstractmethod |
| | def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor: |
| | """ |
| | Apply scaling transformation to data. |
| | """ |
| | pass |
| |
|
| | @abstractmethod |
| | def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor: |
| | """ |
| | Apply inverse scaling transformation to recover original scale. |
| | """ |
| | pass |
| |
|
| |
|
class RobustScaler(BaseScaler):
    """
    Robust scaler using median and IQR for normalization.

    Each channel is centered by its median and divided by its inter-quartile
    range (75th minus 25th percentile). Statistics are computed only over
    observed (non-zero mask) and finite history values, and scaled outputs
    are clipped to [-50, 50].
    """

    def __init__(self, epsilon: float = 1e-6, min_scale: float = 1e-3):
        """
        Args:
            epsilon: Positive constant added to the IQR before dividing.
            min_scale: Positive lower bound applied to the computed IQR and to
                the denominator used by ``scale`` / ``inverse_scale``.

        Raises:
            ValueError: If ``epsilon`` or ``min_scale`` is not positive.
        """
        if epsilon <= 0:
            raise ValueError("epsilon must be positive")
        if min_scale <= 0:
            raise ValueError("min_scale must be positive")
        self.epsilon = epsilon
        self.min_scale = min_scale

    @staticmethod
    def _expand_mask(
        history_mask: torch.Tensor | None, batch_size: int, seq_len: int, num_channels: int
    ) -> torch.Tensor | None:
        """Broadcast a (batch, seq) or (batch, seq, channels|1) mask to a bool (batch, seq, channels) tensor."""
        if history_mask is None:
            return None
        mask = history_mask.bool()
        if mask.ndim == 2:
            # A 2D mask is shared by every channel of a series.
            mask = mask.unsqueeze(-1)
        elif mask.ndim != 3:
            raise ValueError(
                f"history_mask must have 2 or 3 dimensions, got shape {tuple(history_mask.shape)}"
            )
        return mask.expand(batch_size, seq_len, num_channels)

    def _denominator(self, iqr: torch.Tensor) -> torch.Tensor:
        """Return ``max(iqr + epsilon, min_scale)``, the divisor shared by scale and inverse_scale."""
        return torch.clamp(iqr + self.epsilon, min=self.min_scale)

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute per-channel median and IQR from historical data.

        Args:
            history_values: Tensor of shape (batch, seq_len, num_channels).
            history_mask: Optional mask; either (batch, seq_len), shared by all
                channels, or (batch, seq_len, num_channels) / (batch, seq_len, 1)
                for per-channel masking. Non-zero entries mark usable steps.

        Returns:
            ``{"median": ..., "iqr": ...}``, each of shape (batch, 1, num_channels).
            Channels with no usable finite values keep median 0 and IQR 1; a
            single usable value yields IQR ``min_scale``. Statistics keep a
            float64 input's precision; other inputs use at least the default
            float dtype.

        Raises:
            ValueError: If ``history_mask`` is not 2D or 3D.
        """
        batch_size, seq_len, num_channels = history_values.shape
        device = history_values.device
        # Avoid silently downcasting float64 histories; promote ints/halves to the default float.
        dtype = torch.promote_types(history_values.dtype, torch.get_default_dtype())
        mask = self._expand_mask(history_mask, batch_size, seq_len, num_channels)

        medians = torch.zeros(batch_size, 1, num_channels, device=device, dtype=dtype)
        iqrs = torch.ones(batch_size, 1, num_channels, device=device, dtype=dtype)
        floor = torch.tensor(self.min_scale, device=device, dtype=dtype)

        for b in range(batch_size):
            for c in range(num_channels):
                valid_data = history_values[b, :, c]
                if mask is not None:
                    valid_data = valid_data[mask[b, :, c]]
                # Cast before quantile: torch.quantile only accepts float/double inputs.
                valid_data = valid_data.to(dtype)
                valid_data = valid_data[torch.isfinite(valid_data)]
                if valid_data.numel() == 0:
                    continue

                medians[b, 0, c] = torch.median(valid_data)

                if valid_data.numel() == 1:
                    # A single observation has no spread.
                    iqrs[b, 0, c] = self.min_scale
                    continue

                try:
                    iqr_val = torch.quantile(valid_data, 0.75) - torch.quantile(valid_data, 0.25)
                except RuntimeError:
                    # torch.quantile rejects very large inputs; fall back to the std.
                    iqr_val = torch.std(valid_data)
                iqrs[b, 0, c] = torch.maximum(iqr_val, floor)

        return {"median": medians, "iqr": iqrs}

    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply robust scaling: ``(data - median) / max(iqr + epsilon, min_scale)``,
        then clip the result to [-50, 50].
        """
        scaled_data = (data - statistics["median"]) / self._denominator(statistics["iqr"])
        return torch.clamp(scaled_data, -50.0, 50.0)

    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Undo ``scale``: ``scaled_data * max(iqr + epsilon, min_scale) + median``.

        Accepts 3D tensors, or 4D tensors with an extra trailing axis (the
        statistics are unsqueezed on the last dimension to broadcast over it).
        Values clipped by ``scale`` cannot be recovered exactly.
        """
        median = statistics["median"]
        denominator = self._denominator(statistics["iqr"])

        if scaled_data.ndim == 4:
            denominator = denominator.unsqueeze(-1)
            median = median.unsqueeze(-1)

        return scaled_data * denominator + median
| |
|
| |
|
class MinMaxScaler(BaseScaler):
    """
    Min-Max scaler that normalizes data to the range [-1, 1].

    Each channel is mapped with ``2 * (x - min) / (max - min + epsilon) - 1``,
    where ``min`` and ``max`` come from observed (non-zero mask), finite
    history values. Values outside the historical range map outside [-1, 1].
    """

    def __init__(self, epsilon: float = 1e-8):
        """
        Args:
            epsilon: Positive constant added to the range in the denominator;
                also the threshold below which a channel counts as constant.

        Raises:
            ValueError: If ``epsilon`` is not positive.
        """
        if epsilon <= 0:
            raise ValueError("epsilon must be positive")
        self.epsilon = epsilon

    @staticmethod
    def _expand_mask(
        history_mask: torch.Tensor | None, batch_size: int, seq_len: int, num_channels: int
    ) -> torch.Tensor | None:
        """Broadcast a (batch, seq) or (batch, seq, channels|1) mask to a bool (batch, seq, channels) tensor."""
        if history_mask is None:
            return None
        mask = history_mask.bool()
        if mask.ndim == 2:
            # A 2D mask is shared by every channel of a series.
            mask = mask.unsqueeze(-1)
        elif mask.ndim != 3:
            raise ValueError(
                f"history_mask must have 2 or 3 dimensions, got shape {tuple(history_mask.shape)}"
            )
        return mask.expand(batch_size, seq_len, num_channels)

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute per-channel min and max from historical data.

        Args:
            history_values: Tensor of shape (batch, seq_len, num_channels).
            history_mask: Optional mask; either (batch, seq_len), shared by all
                channels, or (batch, seq_len, num_channels) / (batch, seq_len, 1)
                for per-channel masking. Non-zero entries mark usable steps.

        Returns:
            ``{"min": ..., "max": ...}``, each of shape (batch, 1, num_channels).
            Channels with no usable finite values keep min 0 and max 1;
            constant channels get ``max = min + 1`` to avoid a zero range.

        Raises:
            ValueError: If ``history_mask`` is not 2D or 3D.
        """
        batch_size, seq_len, num_channels = history_values.shape
        device = history_values.device
        # Avoid silently downcasting float64 histories; promote ints/halves to the default float.
        dtype = torch.promote_types(history_values.dtype, torch.get_default_dtype())
        mask = self._expand_mask(history_mask, batch_size, seq_len, num_channels)

        mins = torch.zeros(batch_size, 1, num_channels, device=device, dtype=dtype)
        maxs = torch.ones(batch_size, 1, num_channels, device=device, dtype=dtype)

        for b in range(batch_size):
            for c in range(num_channels):
                valid_data = history_values[b, :, c]
                if mask is not None:
                    valid_data = valid_data[mask[b, :, c]]
                valid_data = valid_data.to(dtype)
                # Without this, one NaN/inf sample would make the whole channel's range NaN/inf.
                valid_data = valid_data[torch.isfinite(valid_data)]
                if valid_data.numel() == 0:
                    continue

                min_val = torch.min(valid_data)
                max_val = torch.max(valid_data)

                mins[b, 0, c] = min_val
                if torch.abs(max_val - min_val) < self.epsilon:
                    # Constant channel: use a unit range so scaling stays well-defined.
                    maxs[b, 0, c] = min_val + 1.0
                else:
                    maxs[b, 0, c] = max_val

        return {"min": mins, "max": maxs}

    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply min-max scaling: ``2 * (data - min) / (max - min + epsilon) - 1``.
        """
        min_val = statistics["min"]
        max_val = statistics["max"]

        normalized = (data - min_val) / (max_val - min_val + self.epsilon)
        return normalized * 2.0 - 1.0

    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Undo ``scale``: ``(scaled_data + 1) / 2 * (max - min + epsilon) + min``.

        Accepts 3D tensors, or 4D tensors with an extra trailing axis (the
        statistics are unsqueezed on the last dimension to broadcast over it).
        """
        min_val = statistics["min"]
        max_val = statistics["max"]

        if scaled_data.ndim == 4:
            min_val = min_val.unsqueeze(-1)
            max_val = max_val.unsqueeze(-1)

        normalized = (scaled_data + 1.0) / 2.0
        return normalized * (max_val - min_val + self.epsilon) + min_val
| |
|
| |
|
class MeanScaler(BaseScaler):
    """
    A scaler that centers the data by subtracting the channel-wise mean.

    This scaler only performs centering and does not affect the scale of the
    data. The mean is computed over observed (non-zero mask), finite history
    values.
    """

    @staticmethod
    def _expand_mask(
        history_mask: torch.Tensor | None, batch_size: int, seq_len: int, num_channels: int
    ) -> torch.Tensor | None:
        """Broadcast a (batch, seq) or (batch, seq, channels|1) mask to a bool (batch, seq, channels) tensor."""
        if history_mask is None:
            return None
        mask = history_mask.bool()
        if mask.ndim == 2:
            # A 2D mask is shared by every channel of a series.
            mask = mask.unsqueeze(-1)
        elif mask.ndim != 3:
            raise ValueError(
                f"history_mask must have 2 or 3 dimensions, got shape {tuple(history_mask.shape)}"
            )
        return mask.expand(batch_size, seq_len, num_channels)

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute the mean for each channel from historical data.

        Args:
            history_values: Tensor of shape (batch, seq_len, num_channels).
            history_mask: Optional mask; either (batch, seq_len), shared by all
                channels, or (batch, seq_len, num_channels) / (batch, seq_len, 1)
                for per-channel masking. Non-zero entries mark usable steps.

        Returns:
            ``{"mean": ...}`` of shape (batch, 1, num_channels). Channels with
            no usable finite values keep mean 0. A float64 input keeps float64
            precision; integer inputs are averaged as floats.

        Raises:
            ValueError: If ``history_mask`` is not 2D or 3D.
        """
        batch_size, seq_len, num_channels = history_values.shape
        device = history_values.device
        # Avoid silently downcasting float64 histories; torch.mean also rejects integer dtypes.
        dtype = torch.promote_types(history_values.dtype, torch.get_default_dtype())
        mask = self._expand_mask(history_mask, batch_size, seq_len, num_channels)

        means = torch.zeros(batch_size, 1, num_channels, device=device, dtype=dtype)

        for b in range(batch_size):
            for c in range(num_channels):
                valid_data = history_values[b, :, c]
                if mask is not None:
                    valid_data = valid_data[mask[b, :, c]]
                valid_data = valid_data.to(dtype)
                valid_data = valid_data[torch.isfinite(valid_data)]
                if valid_data.numel() == 0:
                    continue
                means[b, 0, c] = torch.mean(valid_data)

        return {"mean": means}

    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply mean centering: data - mean.
        """
        return data - statistics["mean"]

    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply inverse mean centering: scaled_data + mean.

        Handles both 3D (e.g., training input) and 4D (e.g., model output
        samples) tensors; for 4D input the mean is unsqueezed on the last
        dimension to broadcast over it.
        """
        mean = statistics["mean"]
        if scaled_data.ndim == 4:
            mean = mean.unsqueeze(-1)
        return scaled_data + mean
| |
|
| |
|
class MedianScaler(BaseScaler):
    """
    A scaler that centers the data by subtracting the channel-wise median.

    This scaler only performs centering and does not affect the scale of the
    data. It is more robust to outliers than the MeanScaler. The median is
    computed over observed (non-zero mask), finite history values.
    """

    @staticmethod
    def _expand_mask(
        history_mask: torch.Tensor | None, batch_size: int, seq_len: int, num_channels: int
    ) -> torch.Tensor | None:
        """Broadcast a (batch, seq) or (batch, seq, channels|1) mask to a bool (batch, seq, channels) tensor."""
        if history_mask is None:
            return None
        mask = history_mask.bool()
        if mask.ndim == 2:
            # A 2D mask is shared by every channel of a series.
            mask = mask.unsqueeze(-1)
        elif mask.ndim != 3:
            raise ValueError(
                f"history_mask must have 2 or 3 dimensions, got shape {tuple(history_mask.shape)}"
            )
        return mask.expand(batch_size, seq_len, num_channels)

    def compute_statistics(
        self, history_values: torch.Tensor, history_mask: torch.Tensor | None = None
    ) -> dict[str, torch.Tensor]:
        """
        Compute the median for each channel from historical data.

        Args:
            history_values: Tensor of shape (batch, seq_len, num_channels).
            history_mask: Optional mask; either (batch, seq_len), shared by all
                channels, or (batch, seq_len, num_channels) / (batch, seq_len, 1)
                for per-channel masking. Non-zero entries mark usable steps.

        Returns:
            ``{"median": ...}`` of shape (batch, 1, num_channels). Channels with
            no usable finite values keep median 0. ``torch.median`` returns the
            lower of the two middle values for even counts. A float64 input
            keeps float64 precision.

        Raises:
            ValueError: If ``history_mask`` is not 2D or 3D.
        """
        batch_size, seq_len, num_channels = history_values.shape
        device = history_values.device
        # Avoid silently downcasting float64 histories; promote ints/halves to the default float.
        dtype = torch.promote_types(history_values.dtype, torch.get_default_dtype())
        mask = self._expand_mask(history_mask, batch_size, seq_len, num_channels)

        medians = torch.zeros(batch_size, 1, num_channels, device=device, dtype=dtype)

        for b in range(batch_size):
            for c in range(num_channels):
                valid_data = history_values[b, :, c]
                if mask is not None:
                    valid_data = valid_data[mask[b, :, c]]
                valid_data = valid_data.to(dtype)
                valid_data = valid_data[torch.isfinite(valid_data)]
                if valid_data.numel() == 0:
                    continue
                medians[b, 0, c] = torch.median(valid_data)

        return {"median": medians}

    def scale(self, data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply median centering: data - median.
        """
        return data - statistics["median"]

    def inverse_scale(self, scaled_data: torch.Tensor, statistics: dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Apply inverse median centering: scaled_data + median.

        Handles both 3D (e.g., training input) and 4D (e.g., model output
        samples) tensors; for 4D input the median is unsqueezed on the last
        dimension to broadcast over it.
        """
        median = statistics["median"]
        if scaled_data.ndim == 4:
            median = median.unsqueeze(-1)
        return scaled_data + median
| |
|