File size: 3,394 Bytes
c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 0a58567 c4b87d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import argparse
import logging
import os
import torch
from examples.utils import (
load_model,
run_inference_and_plot,
)
from src.data.containers import BatchTimeSeriesContainer
from src.synthetic_generation.generator_params import SineWaveGeneratorParams
from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import (
SineWaveGeneratorWrapper,
)
# Set up root logging for the whole script: records at INFO level and above,
# each prefixed with a timestamp and the level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def main(argv=None) -> int:
    """Run the quick-start demo: forecast synthetic sine waves and plot the results.

    Generates a batch of synthetic sine-wave series, holds out the last
    ``--future_length`` steps as the forecast target, loads the pretrained
    model on ``cuda:0`` and runs inference/plotting via ``run_inference_and_plot``.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). ``None`` (the default) makes argparse read ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if the checkpoint file is missing.

    Raises:
        SystemExit: With status 2 when the command-line arguments are invalid
            (raised by ``argparse`` / ``parser.error``).
        RuntimeError: If no CUDA device is available.
    """
    # CLI
    parser = argparse.ArgumentParser(description="Quick start demo for TimeSeriesModel")
    parser.add_argument(
        "--config",
        default="configs/example.yaml",
        help="Path to model config YAML (default: configs/example.yaml)",
    )
    parser.add_argument(
        "--checkpoint",
        default="models/checkpoint_38M.pth",
        help="Path to model checkpoint file (default: models/checkpoint_38M.pth)",
    )
    parser.add_argument("--batch_size", type=int, default=3)
    parser.add_argument("--total_length", type=int, default=2048)
    parser.add_argument(
        "--future_length",
        type=int,
        default=256,
        help="Number of trailing time steps held out as the forecast target (default: 256)",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_dir", default="outputs")
    args = parser.parse_args(argv)

    # Validate sizes up front. The history/future split below slices with
    # ``-future_length``: 0 would give an empty history (``x[:, :-0]`` is empty)
    # and, assuming the generator returns ``total_length`` steps, a future
    # window >= total_length would leave no history at all.
    if args.batch_size < 1:
        parser.error("--batch_size must be at least 1")
    if args.future_length < 1:
        parser.error("--future_length must be at least 1")
    if args.total_length <= args.future_length:
        parser.error(
            f"--total_length ({args.total_length}) must be greater than "
            f"--future_length ({args.future_length})"
        )

    # Check that the checkpoint is an actual file (a directory would not load).
    if not os.path.isfile(args.checkpoint):
        logger.error("Checkpoint file not found at: %s", args.checkpoint)
        logger.error(
            "Please ensure the checkpoint file exists at that path "
            "(or that you've cloned the repo with Git LFS)."
        )
        logger.error("You can also specify a different path using --checkpoint.")
        return 1  # Non-zero so shells/CI see the failure.

    logger.info("=== Time Series Model Demo (Univariate Quantile) ===")

    # This demo requires a CUDA GPU; fail fast before generating any data.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to run this demo. No CUDA device detected.")
    device = torch.device("cuda:0")

    # 1) Generate synthetic sine wave data
    sine_params = SineWaveGeneratorParams(global_seed=args.seed, length=args.total_length)
    sine_generator = SineWaveGeneratorWrapper(sine_params)
    batch = sine_generator.generate_batch(batch_size=args.batch_size, seed=args.seed)
    values = torch.from_numpy(batch.values).to(torch.float32)
    if values.ndim == 2:
        values = values.unsqueeze(-1)  # Ensure [B, S, 1] for univariate

    # Split along the time axis: everything before the last `future_length`
    # steps is model input, the tail is the forecast target.
    future_length = args.future_length
    history_values = values[:, :-future_length, :]
    future_values = values[:, -future_length:, :]

    # 2) Load the pretrained model onto the CUDA device.
    model = load_model(config_path=args.config, model_path=args.checkpoint, device=device)

    # 3) Pack tensors into the model's input container
    container = BatchTimeSeriesContainer(
        history_values=history_values.to(device),
        future_values=future_values.to(device),
        start=batch.start,
        frequency=batch.frequency,
    )

    # 4) Run inference (bfloat16 on CUDA) and plot results
    run_inference_and_plot(
        model=model,
        container=container,
        output_dir=args.output_dir,
        use_bfloat16=True,
    )
    logger.info("=== Demo completed successfully! ===")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so a failed run exits non-zero.
    raise SystemExit(main())
|