import logging
import os

import numpy as np
import torch
import yaml

from src.data.containers import BatchTimeSeriesContainer
from src.models.model import TimeSeriesModel
from src.plotting.plot_timeseries import plot_from_container

logger = logging.getLogger(__name__)


def load_model(config_path: str, model_path: str, device: torch.device) -> TimeSeriesModel:
    """Build a TimeSeriesModel from a YAML config and load checkpoint weights into it.

    Args:
        config_path: Path to a YAML file; its ``TimeSeriesModel`` mapping is passed
            as keyword arguments to the ``TimeSeriesModel`` constructor.
        model_path: Path to a checkpoint readable by ``torch.load`` that contains a
            ``"model_state_dict"`` entry.
        device: Device the model is moved to and checkpoint tensors are mapped onto.

    Returns:
        The loaded model, switched to eval mode, on ``device``.

    Raises:
        KeyError: If the config has no ``TimeSeriesModel`` key or the checkpoint
            has no ``model_state_dict`` key.
    """
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources. Consider weights_only=True if the checkpoint holds only
    # tensors/primitive containers.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()  # inference mode for layers such as dropout/batch-norm

    logger.info("Successfully loaded TimeSeriesModel from %s on %s", model_path, device)
    return model


def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,  # [B, P, N, Q]
    model_quantiles: list[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Plot every sample of a batch with ``plot_from_container``.

    Args:
        container: Batch whose samples are plotted; ``container.batch_size`` plots
            are produced.
        predictions_np: Predictions for the whole batch, passed unchanged to
            ``plot_from_container`` for each sample. The layout is presumably
            [B, P, N, Q]; confirm against ``plot_from_container``.
        model_quantiles: Quantile levels of the prediction's last axis, or ``None``
            for non-quantile models.
        output_dir: Directory that PNG files are written to when ``save_plots`` is
            True. It is created on demand.
        show_plots: Forwarded as ``show`` to ``plot_from_container``.
        save_plots: When True, each sample is written to
            ``sine_wave_prediction_sample_<k>.png`` (1-based ``k``).
    """
    if save_plots:
        # Only needed when files are written; avoids empty dirs on show-only runs.
        os.makedirs(output_dir, exist_ok=True)

    for sample_idx in range(container.batch_size):
        sample_num = sample_idx + 1  # 1-based numbering for titles and file names
        output_file = (
            os.path.join(output_dir, f"sine_wave_prediction_sample_{sample_num}.png") if save_plots else None
        )
        plot_from_container(
            batch=container,
            sample_idx=sample_idx,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"Sine Wave Time Series Prediction - Sample {sample_num}",
            output_file=output_file,
            show=show_plots,
        )


def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Run model inference on a batch and plot the predictions.

    Args:
        model: Model called as ``model(container)``. It must return a mapping with
            a ``"result"`` tensor and, optionally, ``"scale_statistics"``.
        container: Input batch. The device of ``container.history_values``
            decides the autocast device type.
        output_dir: Directory for saved plot images.
        use_bfloat16: Request bfloat16 autocast. It only takes effect when the
            inputs are on CUDA.
        show_plots: Whether to display each plot (forwarded to the plotter).
        save_plots: Whether to write each plot to ``output_dir``.
    """
    device_type = "cuda" if container.history_values.device.type == "cuda" else "cpu"
    # bfloat16 autocast is only enabled on CUDA; CPU inference stays in full precision.
    autocast_enabled = use_bfloat16 and device_type == "cuda"

    with (
        torch.no_grad(),
        torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled),
    ):
        model_output = model(container)

    # Upcast before inverse scaling and numpy conversion (numpy has no bfloat16 dtype).
    preds_full = model_output["result"].to(torch.float32)
    if hasattr(model, "scaler") and "scale_statistics" in model_output:
        # Map predictions back from the model's scaled space to the original data scale.
        preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])

    preds_np = preds_full.detach().cpu().numpy()

    # Quantile bands are only meaningful when the model was trained with a quantile loss.
    model_quantiles = model.quantiles if getattr(model, "loss_type", None) == "quantile" else None

    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=show_plots,
        save_plots=save_plots,
    )