Spaces:
Sleeping
Sleeping
Update Gradio app with multiple files
Browse files
utils.py
CHANGED
|
@@ -160,7 +160,8 @@ def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
|
| 160 |
from chronos import BaseChronosPipeline
|
| 161 |
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
|
| 162 |
with torch.no_grad():
|
| 163 |
-
|
|
|
|
| 164 |
forecast_np = forecast.squeeze().cpu().numpy() if isinstance(forecast, torch.Tensor) else np.array(forecast)
|
| 165 |
if forecast_np.ndim > 1:
|
| 166 |
mean_forecast = forecast_np.mean(axis=tuple(range(forecast_np.ndim - 1)))
|
|
|
|
| 160 |
from chronos import BaseChronosPipeline
|
| 161 |
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
|
| 162 |
with torch.no_grad():
|
| 163 |
+
# Fix: Use context_tensor instead of context
|
| 164 |
+
forecast = pipeline.predict(context_tensor=torch.tensor(prices), prediction_length=prediction_days)
|
| 165 |
forecast_np = forecast.squeeze().cpu().numpy() if isinstance(forecast, torch.Tensor) else np.array(forecast)
|
| 166 |
if forecast_np.ndim > 1:
|
| 167 |
mean_forecast = forecast_np.mean(axis=tuple(range(forecast_np.ndim - 1)))
|