Spaces:
Running
Running
| # auto_causal/config.py | |
| """Central configuration for AutoCausal, including LLM client setup.""" | |
| import os | |
| import logging | |
| from typing import Optional | |
| # Langchain imports | |
| from langchain_core.language_models import BaseChatModel | |
| from langchain_openai import ChatOpenAI # Default | |
| from langchain_anthropic import ChatAnthropic # Example | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| # Add other providers if needed, e.g.: | |
| # from langchain_community.chat_models import ChatOllama | |
| from dotenv import load_dotenv | |
| from langchain_deepseek import ChatDeepSeek | |
| # Create a disk-backed SQLite cache: | |
| # Import Together provider | |
| from langchain_together import ChatTogether | |
| # Module-level logger, named after this module's import path. | |
| logger = logging.getLogger(__name__) | |
| # Load the .env file at import time so the os.getenv() lookups in | |
| # get_llm_client can read provider, model and API-key settings from it. | |
| load_dotenv() | |
def get_llm_client(provider: Optional[str] = None, model_name: Optional[str] = None, **kwargs) -> BaseChatModel:
    """Initialize and return the LangChain chat client for the chosen provider.

    Each setting is resolved in this order: explicit argument, then
    environment variable, then built-in default.

    Args:
        provider: One of "openai", "anthropic", "together", "gemini" or
            "deepseek" (case-insensitive). Falls back to the LLM_PROVIDER
            environment variable, then to "openai".
        model_name: Model identifier. Falls back to the LLM_MODEL
            environment variable, then to the provider's default model.
        **kwargs: Extra keyword arguments forwarded to the client
            constructor. ``api_key`` may be given here to bypass the
            provider's ``*_API_KEY`` environment variable. ``temperature``
            defaults to 0 unless the model is o3-mini, o3 or o4-mini.

    Returns:
        The configured chat model instance.

    Raises:
        ValueError: If the provider is unsupported, or no API key was passed
            and the provider's API-key environment variable is unset/empty.
        Exception: Any error raised by the client constructor is logged and
            re-raised unchanged.
    """
    # `or` (rather than a getenv default) also covers variables set to "".
    provider = (provider or os.getenv("LLM_PROVIDER") or "openai").strip().lower()

    # Default model depends on provider; unknown providers borrow the OpenAI
    # default only so the log line below has something to show.
    default_models = {
        "openai": "gpt-4o-mini",
        "anthropic": "claude-3-5-sonnet-latest",
        "together": "deepseek-ai/DeepSeek-V3",
        "gemini": "gemini-2.5-flash",
        "deepseek": "deepseek-chat",
    }
    model_name = (
        model_name
        or os.getenv("LLM_MODEL")
        or default_models.get(provider, default_models["openai"])
    )

    # These models are excluded from the temperature default, so leave the
    # parameter unset unless the caller supplied it.
    if model_name not in ("o3-mini", "o3", "o4-mini"):
        kwargs.setdefault("temperature", 0)

    # Environment variable holding the API key for each supported provider.
    api_key_env_vars = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "together": "TOGETHER_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
    }

    logger.info("Initializing LLM client: Provider='%s', Model='%s'", provider, model_name)

    try:
        env_var = api_key_env_vars.get(provider)
        if env_var is None:
            raise ValueError(f"Unsupported LLM provider: {provider}")

        # An explicitly passed key wins; popping it also prevents a
        # duplicate-keyword TypeError when kwargs are forwarded below.
        api_key = kwargs.pop("api_key", None) or os.getenv(env_var)
        if not api_key:
            raise ValueError(f"{env_var} not found in environment.")

        if provider == "openai":
            return ChatOpenAI(model=model_name, api_key=api_key, **kwargs)
        if provider == "anthropic":
            return ChatAnthropic(model=model_name, api_key=api_key, **kwargs, streaming=False)
        if provider == "together":
            return ChatTogether(model=model_name, api_key=api_key, **kwargs)
        if provider == "gemini":
            # Pass the key explicitly: the client's own default lookup uses
            # GOOGLE_API_KEY, not the GEMINI_API_KEY validated above.
            return ChatGoogleGenerativeAI(
                model=model_name,
                google_api_key=api_key,
                **kwargs,
                function_calling="auto",
            )
        # Only "deepseek" remains; pass the validated key like the others.
        return ChatDeepSeek(model=model_name, api_key=api_key, **kwargs)
        # To support another provider (e.g. Ollama through
        # langchain_community's ChatOllama), add a branch above and, if it
        # needs a key, an entry in `api_key_env_vars`.
    except Exception as e:
        logger.error(
            "Failed to initialize LLM (Provider: %s, Model: %s): %s",
            provider,
            model_name,
            e,
        )
        raise  # Re-raise so callers see the original exception