|
|
import argparse |
|
|
import logging |
|
|
import os |
|
|
|
|
|
import torch |
|
|
from examples.utils import ( |
|
|
load_model, |
|
|
run_inference_and_plot, |
|
|
) |
|
|
from src.data.containers import BatchTimeSeriesContainer |
|
|
from src.synthetic_generation.generator_params import SineWaveGeneratorParams |
|
|
from src.synthetic_generation.sine_waves.sine_wave_generator_wrapper import ( |
|
|
SineWaveGeneratorWrapper, |
|
|
) |
|
|
|
|
|
|
|
|
# Log format shared by every record this script emits.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"

# Configure the root logger at INFO level (no-op if it already has handlers),
# then grab a module-scoped logger for the demo's messages.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def _parse_args(argv=None):
    """Parse and validate the demo's command-line arguments.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.

    Exits via ``parser.error`` (status 2) when ``--batch_size`` is not
    positive or ``--future_length`` does not leave a non-empty history window
    inside ``--total_length``.
    """
    parser = argparse.ArgumentParser(description="Quick start demo for TimeSeriesModel")
    parser.add_argument(
        "--config",
        default="configs/example.yaml",
        help="Path to model config YAML (default: configs/example.yaml)",
    )
    parser.add_argument(
        "--checkpoint",
        default="models/checkpoint_38M.pth",
        help="Path to model checkpoint file (default: models/checkpoint_38M.pth)",
    )
    parser.add_argument("--batch_size", type=int, default=3)
    parser.add_argument("--total_length", type=int, default=2048)
    parser.add_argument(
        "--future_length",
        type=int,
        default=256,
        help="Number of trailing time steps held out as the forecast target (default: 256)",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--output_dir", default="outputs")
    args = parser.parse_args(argv)

    if args.batch_size < 1:
        parser.error(f"--batch_size must be >= 1 (got {args.batch_size})")
    # The series is split as history = values[:, :-future] and
    # future = values[:, -future:]. Both slices must be non-empty; note that
    # future == 0 would make the history empty and the "future" the whole series.
    if not 0 < args.future_length < args.total_length:
        parser.error(
            f"--future_length must satisfy 0 < future_length < total_length "
            f"(got future_length={args.future_length}, total_length={args.total_length})"
        )
    return args


def main():
    """Run the quick-start forecasting demo end to end.

    Parses the CLI arguments and generates a batch of synthetic sine-wave
    series. It then splits each series into a history window and a held-out
    future window of ``--future_length`` steps, loads the model from
    ``--config``/``--checkpoint`` onto ``cuda:0``, and passes the batch to
    ``run_inference_and_plot``.

    Returns:
        0 on success, 1 when the checkpoint file does not exist.

    Raises:
        RuntimeError: If no CUDA device is available.
        SystemExit: If the command-line arguments are invalid.
    """
    args = _parse_args()

    model_path = args.checkpoint
    if not os.path.exists(model_path):
        logger.error(f"Checkpoint file not found at: {model_path}")
        logger.error(
            "Please ensure the checkpoint exists at that path (for the default "
            "'models/checkpoint_38M.pth', make sure you've cloned the repo with Git LFS)."
        )
        logger.error("You can also specify a different path using --checkpoint.")
        return 1

    logger.info("=== Time Series Model Demo (Univariate Quantile) ===")

    # Fail fast, before spending time generating data we cannot run on.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required to run this demo. No CUDA device detected.")
    device = torch.device("cuda:0")

    sine_params = SineWaveGeneratorParams(global_seed=args.seed, length=args.total_length)
    sine_generator = SineWaveGeneratorWrapper(sine_params)
    batch = sine_generator.generate_batch(batch_size=args.batch_size, seed=args.seed)

    values = torch.from_numpy(batch.values).to(torch.float32)
    # The split below indexes axis 1 as time and expects a trailing channel
    # axis. 2-D generator output (presumably univariate (batch, time) -- TODO
    # confirm) gets a singleton channel axis appended.
    if values.ndim == 2:
        values = values.unsqueeze(-1)

    future_length = args.future_length
    history_values = values[:, :-future_length, :]
    future_values = values[:, -future_length:, :]

    model = load_model(config_path=args.config, model_path=model_path, device=device)

    container = BatchTimeSeriesContainer(
        history_values=history_values.to(device),
        future_values=future_values.to(device),
        start=batch.start,
        frequency=batch.frequency,
    )

    run_inference_and_plot(model=model, container=container, output_dir=args.output_dir, use_bfloat16=True)

    logger.info("=== Demo completed successfully! ===")
    return 0
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None maps to 0),
    # so a failed run is visible to shells and CI.
    raise SystemExit(main())
|
|
|