omniverse1 commited on
Commit
c1b42c2
·
verified ·
1 Parent(s): 823c183

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +4 -1
utils.py CHANGED
@@ -265,7 +265,9 @@ def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
265
  from chronos import BaseChronosPipeline
266
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
267
  with torch.no_grad():
268
- forecast = pipeline.predict(context_tensor=torch.tensor(prices), prediction_length=prediction_days)
 
 
269
  forecast_np = forecast.squeeze().cpu().numpy() if isinstance(forecast, torch.Tensor) else np.array(forecast)
270
  if forecast_np.ndim > 1:
271
  mean_forecast = forecast_np.mean(axis=tuple(range(forecast_np.ndim - 1)))
@@ -337,3 +339,4 @@ def create_technical_chart(data, indicators):
337
  fig.update_layout(title='Technical Indicators Overview', height=800, showlegend=False, hovermode='x unified')
338
  return fig
339
 
 
 
265
  from chronos import BaseChronosPipeline
266
  pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-bolt-base", device_map="auto")
267
  with torch.no_grad():
268
+ # FIX: Mengganti 'context_tensor' menjadi 'context'
269
+ forecast = pipeline.predict(context=torch.tensor(prices), prediction_length=prediction_days)
270
+
271
  forecast_np = forecast.squeeze().cpu().numpy() if isinstance(forecast, torch.Tensor) else np.array(forecast)
272
  if forecast_np.ndim > 1:
273
  mean_forecast = forecast_np.mean(axis=tuple(range(forecast_np.ndim - 1)))
 
339
  fig.update_layout(title='Technical Indicators Overview', height=800, showlegend=False, hovermode='x unified')
340
  return fig
341
 
342
+