# TempoPFN/src/gift_eval/evaluate.py
import argparse
import logging
import warnings
from pathlib import Path
import matplotlib
from gluonts.model.evaluation import evaluate_model
from gluonts.time_feature import get_seasonality
from linear_operator.utils.cholesky import NumericalWarning
from src.gift_eval.constants import (
DATASET_PROPERTIES,
MED_LONG_DATASETS,
METRICS,
PRETTY_NAMES,
)
from src.gift_eval.core import DatasetMetadata, EvaluationItem, expand_datasets_arg
from src.gift_eval.data import Dataset
from src.gift_eval.predictor import TimeSeriesPredictor
from src.gift_eval.results import write_results_to_disk
from src.plotting.gift_eval_utils import create_plots_for_dataset
logger = logging.getLogger(__name__)
# Warnings configuration
warnings.filterwarnings("ignore", category=NumericalWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=DeprecationWarning)
matplotlib.set_loglevel("WARNING")
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("matplotlib.font_manager").setLevel(logging.WARNING)
logging.getLogger("PIL").setLevel(logging.WARNING)
class WarningFilter(logging.Filter):
def __init__(self, text_to_filter: str) -> None:
super().__init__()
self.text_to_filter = text_to_filter
def filter(self, record: logging.LogRecord) -> bool:
return self.text_to_filter not in record.getMessage()
# Filter out gluonts warnings about mean predictions
gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(WarningFilter("The mean prediction is not stored in the forecast data"))
def construct_evaluation_data(
dataset_name: str,
dataset_storage_path: str,
terms: list[str] | None = None,
max_windows: int | None = None,
) -> list[tuple[Dataset, DatasetMetadata]]:
"""Build datasets and rich metadata per term for a dataset name."""
if terms is None:
terms = ["short", "medium", "long"]
sub_datasets: list[tuple[Dataset, DatasetMetadata]] = []
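    # Dataset names may arrive either as "name/FREQ" or as a bare name; in the
    # bare case the frequency is looked up from DATASET_PROPERTIES.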
if "/" in dataset_name:
ds_key, ds_freq = dataset_name.split("/")
ds_key = ds_key.lower()
ds_key = PRETTY_NAMES.get(ds_key, ds_key)
else:
ds_key = dataset_name.lower()
ds_key = PRETTY_NAMES.get(ds_key, ds_key)
ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")
for term in terms:
# Skip medium/long terms for datasets that don't support them
if (term == "medium" or term == "long") and dataset_name not in MED_LONG_DATASETS:
continue
# Probe once to determine dimensionality
probe_dataset = Dataset(
name=dataset_name,
term=term,
to_univariate=False,
storage_path=dataset_storage_path,
max_windows=max_windows,
)
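        # If the probe reveals a multivariate target, reload the dataset with each
        # dimension split into its own univariate series.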
to_univariate = probe_dataset.target_dim > 1
dataset = Dataset(
name=dataset_name,
term=term,
to_univariate=to_univariate,
storage_path=dataset_storage_path,
max_windows=max_windows,
)
# Compute metadata
season_length = get_seasonality(dataset.freq)
actual_freq = ds_freq if ds_freq else dataset.freq
metadata = DatasetMetadata(
full_name=f"{ds_key}/{actual_freq}/{term}",
key=ds_key,
freq=actual_freq,
term=term,
season_length=season_length,
target_dim=probe_dataset.target_dim,
to_univariate=to_univariate,
prediction_length=dataset.prediction_length,
windows=dataset.windows,
)
sub_datasets.append((dataset, metadata))
return sub_datasets
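
# A minimal usage sketch for construct_evaluation_data (the dataset name and
# storage path below are placeholders, not values shipped with this module):
#
#   sub_datasets = construct_evaluation_data(
#       dataset_name="m4_weekly",
#       dataset_storage_path="/data/gift_eval",
#       terms=["short"],
#   )
#   for ds, meta in sub_datasets:
#       print(meta.full_name, ds.prediction_length, ds.windows)
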
def evaluate_datasets(
predictor: TimeSeriesPredictor,
dataset: str,
dataset_storage_path: str,
terms: list[str] | None = None,
max_windows: int | None = None,
batch_size: int = 48,
max_context_length: int | None = 1024,
create_plots: bool = False,
max_plots_per_dataset: int = 10,
) -> list[EvaluationItem]:
"""Evaluate predictor on one dataset across the requested terms."""
if terms is None:
terms = ["short", "medium", "long"]
sub_datasets = construct_evaluation_data(
dataset_name=dataset,
dataset_storage_path=dataset_storage_path,
terms=terms,
max_windows=max_windows,
)
results: list[EvaluationItem] = []
for i, (sub_dataset, metadata) in enumerate(sub_datasets):
logger.info(f"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}")
logger.info(f" Dataset size: {len(sub_dataset.test_data)}")
logger.info(f" Frequency: {sub_dataset.freq}")
logger.info(f" Term: {metadata.term}")
logger.info(f" Prediction length: {sub_dataset.prediction_length}")
logger.info(f" Target dimensions: {sub_dataset.target_dim}")
logger.info(f" Windows: {sub_dataset.windows}")
# Update context on the reusable predictor
predictor.set_dataset_context(
prediction_length=sub_dataset.prediction_length,
freq=sub_dataset.freq,
batch_size=batch_size,
max_context_length=max_context_length,
)
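        # axis=None asks gluonts to aggregate each metric over all series and
        # windows, yielding a single value per metric for this dataset/term.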
res = evaluate_model(
model=predictor,
test_data=sub_dataset.test_data,
metrics=METRICS,
axis=None,
mask_invalid_label=True,
allow_nan_forecast=False,
seasonality=metadata.season_length,
)
figs: list[tuple[object, str]] = []
if create_plots:
forecasts = predictor.predict(sub_dataset.test_data.input)
figs = create_plots_for_dataset(
forecasts=forecasts,
test_data=sub_dataset.test_data,
dataset_metadata=metadata,
max_plots=max_plots_per_dataset,
max_context_length=max_context_length,
)
results.append(EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs))
return results
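
# Sketch of consuming the returned evaluation items (predictor construction is
# omitted; attribute names mirror EvaluationItem as used in this module):
#
#   items = evaluate_datasets(predictor, "m4_weekly", "/data/gift_eval", terms=["short"])
#   for item in items:
#       print(item.dataset_metadata.full_name)
#       print(item.metrics)  # metrics table returned by gluonts evaluate_model
#       # item.figures is empty unless create_plots=True
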
def _run_evaluation(
predictor: TimeSeriesPredictor,
datasets: list[str] | str,
terms: list[str],
dataset_storage_path: str,
max_windows: int | None = None,
batch_size: int = 48,
max_context_length: int | None = 1024,
output_dir: str = "gift_eval_results",
model_name: str = "TimeSeriesModel",
create_plots: bool = False,
max_plots: int = 10,
) -> None:
"""Shared evaluation workflow used by both entry points."""
datasets_to_run = expand_datasets_arg(datasets)
results_root = Path(output_dir)
for ds_name in datasets_to_run:
items = evaluate_datasets(
predictor=predictor,
dataset=ds_name,
dataset_storage_path=dataset_storage_path,
terms=terms,
max_windows=max_windows,
batch_size=batch_size,
max_context_length=max_context_length,
create_plots=create_plots,
max_plots_per_dataset=max_plots,
)
write_results_to_disk(
items=items,
dataset_name=ds_name,
output_dir=results_root,
model_name=model_name,
create_plots=create_plots,
)
def evaluate_from_paths(
model_path: str,
config_path: str,
datasets: list[str] | str,
terms: list[str],
dataset_storage_path: str,
max_windows: int | None = None,
batch_size: int = 48,
max_context_length: int | None = 1024,
output_dir: str = "gift_eval_results",
model_name: str = "TimeSeriesModel",
create_plots: bool = False,
max_plots: int = 10,
) -> None:
"""Entry point: load model from disk and save metrics/plots to disk."""
# Validate inputs early
if not Path(model_path).exists():
raise FileNotFoundError(f"Model path does not exist: {model_path}")
if not Path(config_path).exists():
raise FileNotFoundError(f"Config path does not exist: {config_path}")
predictor = TimeSeriesPredictor.from_paths(
model_path=model_path,
config_path=config_path,
ds_prediction_length=1, # placeholder; set per dataset below
ds_freq="D", # placeholder; set per dataset below
batch_size=batch_size,
max_context_length=max_context_length,
)
_run_evaluation(
predictor=predictor,
datasets=datasets,
terms=terms,
dataset_storage_path=dataset_storage_path,
max_windows=max_windows,
batch_size=batch_size,
max_context_length=max_context_length,
output_dir=output_dir,
model_name=model_name,
create_plots=create_plots,
max_plots=max_plots,
)
def evaluate_in_memory(
model,
config: dict,
datasets: list[str] | str,
terms: list[str],
dataset_storage_path: str,
max_windows: int | None = None,
batch_size: int = 48,
max_context_length: int | None = 1024,
output_dir: str = "gift_eval_results",
model_name: str = "TimeSeriesModel",
create_plots: bool = False,
max_plots: int = 10,
) -> None:
"""Entry point: evaluate in-memory model and return results per dataset."""
predictor = TimeSeriesPredictor.from_model(
model=model,
config=config,
ds_prediction_length=1, # placeholder; set per dataset below
ds_freq="D", # placeholder; set per dataset below
batch_size=batch_size,
max_context_length=max_context_length,
)
_run_evaluation(
predictor=predictor,
datasets=datasets,
terms=terms,
dataset_storage_path=dataset_storage_path,
max_windows=max_windows,
batch_size=batch_size,
max_context_length=max_context_length,
output_dir=output_dir,
model_name=model_name,
create_plots=create_plots,
max_plots=max_plots,
)
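
# Sketch: evaluating a freshly trained model without writing a checkpoint first.
# The `trained_model` and `model_config` names are illustrative; they should be
# whatever TimeSeriesPredictor.from_model accepts in this repository.
#
#   evaluate_in_memory(
#       model=trained_model,
#       config=model_config,
#       datasets=["m4_weekly"],
#       terms=["short"],
#       dataset_storage_path="/data/gift_eval",
#       output_dir="gift_eval_results",
#       model_name="TempoPFN",
#   )
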
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Evaluate TimeSeriesModel on GIFT-Eval datasets")
# Model configuration
parser.add_argument(
"--model_path",
type=str,
required=True,
help="Path to the trained model checkpoint",
)
parser.add_argument(
"--config_path",
type=str,
required=True,
help="Path to the model configuration YAML file",
)
parser.add_argument(
"--model_name",
type=str,
default="TimeSeriesModel",
help="Name identifier for the model",
)
# Dataset configuration
parser.add_argument(
"--datasets",
type=str,
default="all",
help="Comma-separated list of dataset names to evaluate (or 'all')",
)
parser.add_argument(
"--dataset_storage_path",
type=str,
default="/work/dlclarge2/moroshav-GiftEvalPretrain/gift_eval",
help="Path to the dataset storage directory (default: GIFT_EVAL)",
)
parser.add_argument(
"--terms",
type=str,
default="short,medium,long",
help="Comma-separated list of prediction terms to evaluate",
)
parser.add_argument(
"--max_windows",
type=int,
default=None,
help="Maximum number of windows to use for evaluation",
)
# Inference configuration
parser.add_argument("--batch_size", type=int, default=48, help="Batch size for model inference")
parser.add_argument(
"--max_context_length",
type=int,
default=1024,
help="Maximum context length to use (None for no limit)",
)
# Output configuration
parser.add_argument(
"--output_dir",
type=str,
default="gift_eval_results",
help="Directory to save evaluation results",
)
# Plotting configuration
parser.add_argument(
"--create_plots",
action="store_true",
help="Create and save plots for each evaluation window",
)
parser.add_argument(
"--max_plots_per_dataset",
type=int,
default=10,
help="Maximum number of plots to create per dataset term",
)
args = parser.parse_args()
args.terms = args.terms.split(",")
args.datasets = args.datasets.split(",")
return args
def _configure_logging() -> None:
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
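
# Example CLI invocation (paths and dataset names are placeholders), run from the
# repository root so that the `src` package is importable:
#
#   python -m src.gift_eval.evaluate \
#       --model_path checkpoints/model.pt \
#       --config_path configs/model.yaml \
#       --datasets m4_weekly \
#       --terms short \
#       --output_dir gift_eval_results
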
if __name__ == "__main__":
_configure_logging()
args = _parse_args()
logger.info(f"Command Line Arguments: {vars(args)}")
try:
evaluate_from_paths(
model_path=args.model_path,
config_path=args.config_path,
datasets=args.datasets,
terms=args.terms,
dataset_storage_path=args.dataset_storage_path,
max_windows=args.max_windows,
batch_size=args.batch_size,
max_context_length=args.max_context_length,
output_dir=args.output_dir,
model_name=args.model_name,
create_plots=args.create_plots,
max_plots=args.max_plots_per_dataset,
)
except Exception as e:
logger.error(f"Evaluation failed: {str(e)}")
raise