omniverse1 commited on
Commit
37ef262
·
verified ·
1 Parent(s): 8dd55d0

Update Gradio app with multiple files

Browse files
Files changed (1) hide show
  1. utils.py +2 -1
utils.py CHANGED
@@ -160,7 +160,8 @@ def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
160
  from chronos import BaseChronosPipeline
161
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
162
  with torch.no_grad():
163
- forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
 
164
  forecast_np = forecast.squeeze().cpu().numpy() if isinstance(forecast, torch.Tensor) else np.array(forecast)
165
  if forecast_np.ndim > 1:
166
  mean_forecast = forecast_np.mean(axis=tuple(range(forecast_np.ndim - 1)))
 
160
  from chronos import BaseChronosPipeline
161
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
162
  with torch.no_grad():
163
+ # Fix: Use context_tensor instead of context
164
+ forecast = pipeline.predict(context_tensor=torch.tensor(prices), prediction_length=prediction_days)
165
  forecast_np = forecast.squeeze().cpu().numpy() if isinstance(forecast, torch.Tensor) else np.array(forecast)
166
  if forecast_np.ndim > 1:
167
  mean_forecast = forecast_np.mean(axis=tuple(range(forecast_np.ndim - 1)))