JavadBayazi commited on
Commit
32511bd
·
1 Parent(s): 1a671f5

Upgrade to Chronos-2 with DataFrame API support

Browse files
Files changed (2) hide show
  1. app.py +19 -10
  2. requirements.txt +2 -0
app.py CHANGED
@@ -1,17 +1,17 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import torch
4
- from chronos import ChronosPipeline
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
 
8
  # Load the Chronos Pipeline model
9
  @st.cache_resource
10
  def load_pipeline():
11
- pipeline = ChronosPipeline.from_pretrained(
12
  "amazon/chronos-2",
13
  device_map="cpu", # Change to CPU
14
- torch_dtype=torch.float32, # Use float32 for CPU
15
  )
16
  return pipeline
17
 
@@ -53,19 +53,28 @@ prediction_length = st.slider("Select Forecast Horizon (Months)", min_value=1, m
53
 
54
  # If data is valid, perform the forecast
55
  if time_series_data:
56
- # Convert the data to a tensor
57
- context = torch.tensor(time_series_data, dtype=torch.float32)
 
 
 
 
58
 
59
- # Make the forecast
60
- forecast = pipeline.predict(
61
- inputs=context,
62
  prediction_length=prediction_length,
63
- num_samples=20,
 
 
 
64
  )
65
 
66
  # Prepare forecast data for plotting
67
  forecast_index = range(len(time_series_data), len(time_series_data) + prediction_length)
68
- low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
 
 
69
 
70
  # Plot the historical and forecasted data
71
  plt.figure(figsize=(8, 4))
 
1
  import streamlit as st
2
  import pandas as pd
3
  import torch
4
+ from chronos import Chronos2Pipeline
5
  import matplotlib.pyplot as plt
6
  import numpy as np
7
 
8
  # Load the Chronos Pipeline model
9
  @st.cache_resource
10
  def load_pipeline():
11
+ pipeline = Chronos2Pipeline.from_pretrained(
12
  "amazon/chronos-2",
13
  device_map="cpu", # Change to CPU
14
+ dtype=torch.float32, # Use float32 for CPU
15
  )
16
  return pipeline
17
 
 
53
 
54
  # If data is valid, perform the forecast
55
  if time_series_data:
56
+ # Create a DataFrame for Chronos-2
57
+ context_df = pd.DataFrame({
58
+ 'timestamp': pd.date_range(start='2020-01-01', periods=len(time_series_data), freq='ME'),
59
+ 'target': time_series_data,
60
+ 'id': 'series_1'
61
+ })
62
 
63
+ # Make the forecast using Chronos-2 API
64
+ pred_df = pipeline.predict_df(
65
+ context_df,
66
  prediction_length=prediction_length,
67
+ quantile_levels=[0.1, 0.5, 0.9],
68
+ id_column="id",
69
+ timestamp_column="timestamp",
70
+ target="target",
71
  )
72
 
73
  # Prepare forecast data for plotting
74
  forecast_index = range(len(time_series_data), len(time_series_data) + prediction_length)
75
+ median = pred_df["predictions"].values
76
+ low = pred_df["0.1"].values
77
+ high = pred_df["0.9"].values
78
 
79
  # Plot the historical and forecasted data
80
  plt.figure(figsize=(8, 4))
requirements.txt CHANGED
@@ -3,3 +3,5 @@ transformers
3
  torch
4
  git+https://github.com/amazon-science/chronos-forecasting.git
5
  matplotlib
 
 
 
3
  torch
4
  git+https://github.com/amazon-science/chronos-forecasting.git
5
  matplotlib
6
+ pandas
7
+ pyarrow