Spaces:
Running
Running
Commit
·
32511bd
1
Parent(s):
1a671f5
Upgrade to Chronos-2 with DataFrame API support
Browse files- app.py +19 -10
- requirements.txt +2 -0
app.py
CHANGED
|
@@ -1,17 +1,17 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import torch
|
| 4 |
-
from chronos import
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
| 7 |
|
| 8 |
# Load the Chronos Pipeline model
|
| 9 |
@st.cache_resource
|
| 10 |
def load_pipeline():
|
| 11 |
-
pipeline =
|
| 12 |
"amazon/chronos-2",
|
| 13 |
device_map="cpu", # Change to CPU
|
| 14 |
-
|
| 15 |
)
|
| 16 |
return pipeline
|
| 17 |
|
|
@@ -53,19 +53,28 @@ prediction_length = st.slider("Select Forecast Horizon (Months)", min_value=1, m
|
|
| 53 |
|
| 54 |
# If data is valid, perform the forecast
|
| 55 |
if time_series_data:
|
| 56 |
-
#
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
-
# Make the forecast
|
| 60 |
-
|
| 61 |
-
|
| 62 |
prediction_length=prediction_length,
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
| 64 |
)
|
| 65 |
|
| 66 |
# Prepare forecast data for plotting
|
| 67 |
forecast_index = range(len(time_series_data), len(time_series_data) + prediction_length)
|
| 68 |
-
|
|
|
|
|
|
|
| 69 |
|
| 70 |
# Plot the historical and forecasted data
|
| 71 |
plt.figure(figsize=(8, 4))
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import pandas as pd
|
| 3 |
import torch
|
| 4 |
+
from chronos import Chronos2Pipeline
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
| 7 |
|
| 8 |
# Load the Chronos Pipeline model
|
| 9 |
@st.cache_resource
|
| 10 |
def load_pipeline():
|
| 11 |
+
pipeline = Chronos2Pipeline.from_pretrained(
|
| 12 |
"amazon/chronos-2",
|
| 13 |
device_map="cpu", # Change to CPU
|
| 14 |
+
dtype=torch.float32, # Use float32 for CPU
|
| 15 |
)
|
| 16 |
return pipeline
|
| 17 |
|
|
|
|
| 53 |
|
| 54 |
# If data is valid, perform the forecast
|
| 55 |
if time_series_data:
|
| 56 |
+
# Create a DataFrame for Chronos-2
|
| 57 |
+
context_df = pd.DataFrame({
|
| 58 |
+
'timestamp': pd.date_range(start='2020-01-01', periods=len(time_series_data), freq='ME'),
|
| 59 |
+
'target': time_series_data,
|
| 60 |
+
'id': 'series_1'
|
| 61 |
+
})
|
| 62 |
|
| 63 |
+
# Make the forecast using Chronos-2 API
|
| 64 |
+
pred_df = pipeline.predict_df(
|
| 65 |
+
context_df,
|
| 66 |
prediction_length=prediction_length,
|
| 67 |
+
quantile_levels=[0.1, 0.5, 0.9],
|
| 68 |
+
id_column="id",
|
| 69 |
+
timestamp_column="timestamp",
|
| 70 |
+
target="target",
|
| 71 |
)
|
| 72 |
|
| 73 |
# Prepare forecast data for plotting
|
| 74 |
forecast_index = range(len(time_series_data), len(time_series_data) + prediction_length)
|
| 75 |
+
median = pred_df["predictions"].values
|
| 76 |
+
low = pred_df["0.1"].values
|
| 77 |
+
high = pred_df["0.9"].values
|
| 78 |
|
| 79 |
# Plot the historical and forecasted data
|
| 80 |
plt.figure(figsize=(8, 4))
|
requirements.txt
CHANGED
|
@@ -3,3 +3,5 @@ transformers
|
|
| 3 |
torch
|
| 4 |
git+https://github.com/amazon-science/chronos-forecasting.git
|
| 5 |
matplotlib
|
|
|
|
|
|
|
|
|
| 3 |
torch
|
| 4 |
git+https://github.com/amazon-science/chronos-forecasting.git
|
| 5 |
matplotlib
|
| 6 |
+
pandas
|
| 7 |
+
pyarrow
|