"""Chat orchestrator for coordinating all schema translation components.""" import logging import time from datetime import datetime, timezone from typing import Dict, List, Optional, Any from schema_translator.agents import QueryUnderstandingAgent, SchemaAnalyzerAgent from schema_translator.config import Config from schema_translator.database_executor import DatabaseExecutor from schema_translator.knowledge_graph import SchemaKnowledgeGraph from schema_translator.models import ( HarmonizedResult, QueryFeedback, SemanticQueryPlan, ) from schema_translator.query_compiler import QueryCompiler from schema_translator.result_harmonizer import ResultHarmonizer from schema_translator.feedback_loop import FeedbackLoop from schema_translator.schema_drift_detector import SchemaDriftDetector # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) class ChatOrchestrator: """Orchestrates all components for natural language query processing.""" def __init__( self, config: Optional[Config] = None, knowledge_graph: Optional[SchemaKnowledgeGraph] = None, use_llm: bool = True ): """Initialize the chat orchestrator. Args: config: Configuration object (creates new if None) knowledge_graph: Knowledge graph (loads from config if None) use_llm: Whether to use LLM agents (False for mock mode) """ self.config = config or Config() self.use_llm = use_llm # Load knowledge graph if knowledge_graph is None: logger.info("Loading knowledge graph...") self.knowledge_graph = SchemaKnowledgeGraph() self.knowledge_graph.load(self.config.knowledge_graph_path) else: self.knowledge_graph = knowledge_graph # Initialize components logger.info("Initializing components...") self.executor = DatabaseExecutor() self.compiler = QueryCompiler(self.knowledge_graph) self.result_harmonizer = ResultHarmonizer(self.knowledge_graph, self.executor) # Initialize agents (if using LLM) if self.use_llm: logger.info("Initializing LLM agents...") self.query_agent = QueryUnderstandingAgent( self.config.anthropic_api_key, self.knowledge_graph, self.config ) self.schema_agent = SchemaAnalyzerAgent( self.config.anthropic_api_key, self.knowledge_graph, self.config ) else: logger.info("Running in mock mode (no LLM)") self.query_agent = None self.schema_agent = None # Initialize query history self.query_history: List[Dict[str, Any]] = [] # Initialize feedback loop self.feedback_loop = FeedbackLoop() # Initialize drift detector self.drift_detector = SchemaDriftDetector( self.executor, self.knowledge_graph ) logger.info("ChatOrchestrator initialized successfully") def process_query( self, query_text: str, customer_ids: Optional[List[str]] = None, debug: bool = False ) -> Dict[str, Any]: """Process a natural language query end-to-end. 

        Args:
            query_text: Natural language query
            customer_ids: Optional list of customer IDs to query (all if None)
            debug: Whether to include debug information

        Returns:
            Dictionary with results and metadata
        """
        start_time = time.time()
        logger.info(f"Processing query: '{query_text}'")

        try:
            # Step 1: Validate query
            if not self._validate_query(query_text):
                raise ValueError("Invalid query: Query text is empty or too short")

            # Step 2: Parse query to semantic plan
            logger.info("Parsing query to semantic plan...")
            semantic_plan = self._parse_query(query_text)

            if debug:
                logger.info(f"Semantic plan: {semantic_plan}")

            # Step 3: Determine which customers to query
            # Priority: explicit customer_ids parameter > target_customers from query > all customers
            target_customers = customer_ids
            if target_customers is None and semantic_plan.target_customers:
                target_customers = semantic_plan.target_customers
                logger.info(f"Extracted target customers from query: {target_customers}")

            # Step 4: Execute query across customers
            logger.info(f"Executing query across {len(target_customers) if target_customers else 'all'} customers...")
            result = self.result_harmonizer.execute_across_customers(
                semantic_plan,
                customer_ids=target_customers
            )

            # Step 5: Calculate total execution time
            total_time_ms = (time.time() - start_time) * 1000

            logger.info(
                f"Query completed: {result.total_count} rows, "
                f"{result.success_rate:.1f}% success rate, "
                f"{total_time_ms:.2f}ms"
            )

            # Step 6: Add to history
            self._add_to_history(
                query_text=query_text,
                semantic_plan=semantic_plan,
                result=result,
                execution_time_ms=total_time_ms,
                error=None
            )

            # Step 7: Build response
            response = {
                "success": True,
                "query_text": query_text,
                "semantic_plan": semantic_plan if debug else None,
                "result": result,
                "execution_time_ms": total_time_ms,
                "error": None
            }

            if debug:
                response["debug"] = self._build_debug_info(
                    semantic_plan, result, customer_ids
                )

            return response

        except Exception as e:
            error_msg = str(e)
            total_time_ms = (time.time() - start_time) * 1000
            logger.error(f"Query failed: {error_msg}", exc_info=True)

            # Add failed query to history
            self._add_to_history(
                query_text=query_text,
                semantic_plan=None,
                result=None,
                execution_time_ms=total_time_ms,
                error=error_msg
            )

            return {
                "success": False,
                "query_text": query_text,
                "semantic_plan": None,
                "result": None,
                "execution_time_ms": total_time_ms,
                "error": error_msg
            }

    def _validate_query(self, query_text: str) -> bool:
        """Validate a query before processing.

        Args:
            query_text: Query text to validate

        Returns:
            True if valid, False otherwise
        """
        if not query_text or not query_text.strip():
            return False

        if len(query_text.strip()) < 3:
            return False

        return True

    def _parse_query(self, query_text: str) -> SemanticQueryPlan:
        """Parse natural language query to semantic plan.

        Args:
            query_text: Natural language query

        Returns:
            SemanticQueryPlan
        """
        if self.use_llm and self.query_agent:
            # Use LLM agent
            return self.query_agent.understand_query(query_text)
        else:
            # Mock mode: create a simple plan
            from schema_translator.models import QueryIntent

            return SemanticQueryPlan(
                intent=QueryIntent.FIND_CONTRACTS,
                projections=["contract_identifier", "contract_status", "contract_value"],
                filters=[],
                aggregations=[],
                limit=10
            )

    def _build_debug_info(
        self,
        semantic_plan: SemanticQueryPlan,
        result: HarmonizedResult,
        customer_ids: Optional[List[str]]
    ) -> Dict[str, Any]:
        """Build debug information for a query.

        Args:
            semantic_plan: Semantic query plan
            result: Execution result
            customer_ids: Customer IDs queried

        Returns:
            Debug information dictionary
        """
        debug_info = {
            "semantic_plan": {
                "intent": semantic_plan.intent.value if hasattr(semantic_plan.intent, 'value') else str(semantic_plan.intent),
                "projections": semantic_plan.projections,
                "filters": [
                    {
                        "concept": f.concept,
                        "operator": f.operator.value if hasattr(f.operator, 'value') else str(f.operator),
                        "value": f.value
                    }
                    for f in (semantic_plan.filters or [])
                ],
                "aggregations": [
                    {
                        "function": a.function,
                        "concept": a.concept
                    }
                    for a in (semantic_plan.aggregations or [])
                ],
                "limit": semantic_plan.limit
            },
            "customers": {
                "queried": result.customers_queried,
                "succeeded": result.customers_succeeded,
                "failed": result.customers_failed
            },
            "sql_queries": {}
        }

        # Add SQL for each customer
        target_customers = customer_ids if customer_ids else result.customers_queried
        for customer_id in target_customers[:3]:  # Limit to 3 for brevity
            try:
                sql = self.compiler.compile_for_customer(semantic_plan, customer_id)
                debug_info["sql_queries"][customer_id] = sql
            except Exception as e:
                debug_info["sql_queries"][customer_id] = f"Error: {e}"

        return debug_info

    def get_query_history(self, n: int = 10) -> List[Dict[str, Any]]:
        """Get recent query history.

        Args:
            n: Number of recent queries to return

        Returns:
            List of query records
        """
        return self.query_history[-n:] if self.query_history else []

    def get_failed_queries(self) -> List[Dict[str, Any]]:
        """Get all failed queries from history.

        Returns:
            List of failed query records
        """
        return [q for q in self.query_history if not q["success"]]

    def get_statistics(self) -> Dict[str, Any]:
        """Get query execution statistics.

        Returns:
            Statistics dictionary
        """
        stats = self._get_query_statistics()
        stats["knowledge_graph"] = self.knowledge_graph.get_stats()
        return stats

    def submit_feedback(
        self,
        query_text: str,
        semantic_plan: SemanticQueryPlan,
        feedback_type: str,
        feedback_text: Optional[str] = None
    ) -> QueryFeedback:
        """Submit feedback on a query result.

        Args:
            query_text: Original query text
            semantic_plan: Semantic query plan used
            feedback_type: Type of feedback (incorrect, missing, good)
            feedback_text: Optional feedback comment

        Returns:
            QueryFeedback object
        """
        # Submit to feedback loop
        feedback = self.feedback_loop.submit_feedback(
            query_text=query_text,
            semantic_plan=semantic_plan,
            feedback_type=feedback_type,
            feedback_text=feedback_text
        )

        logger.info(f"Feedback received: {feedback_type} for query '{query_text}'")
        return feedback

    def explain_query(self, query_text: str) -> Dict[str, Any]:
        """Explain how a query will be processed without executing it.

        Args:
            query_text: Natural language query

        Returns:
            Explanation dictionary
        """
        try:
            # Parse to semantic plan
            semantic_plan = self._parse_query(query_text)

            # Get human-readable explanation
            if self.use_llm and self.query_agent:
                explanation = self.query_agent.explain_query_plan(semantic_plan)
            else:
                explanation = f"Will find contracts with projections: {semantic_plan.projections}"

            # Get sample SQL for a few customers
            sample_sql = {}
            for customer_id in ["customer_a", "customer_b", "customer_c"]:
                try:
                    sql = self.compiler.compile_for_customer(semantic_plan, customer_id)
                    sample_sql[customer_id] = sql
                except Exception as e:
                    sample_sql[customer_id] = f"Error: {e}"

            return {
                "success": True,
                "query_text": query_text,
                "explanation": explanation,
                "semantic_plan": semantic_plan,
                "sample_sql": sample_sql
            }

        except Exception as e:
            return {
                "success": False,
                "query_text": query_text,
                "error": str(e)
            }

    def list_available_customers(self) -> List[str]:
        """Get list of available customer IDs.

        Returns:
            List of customer IDs
        """
        db_dir = self.config.database_dir
        db_files = list(db_dir.glob("customer_*.db"))
        return sorted([f.stem for f in db_files])

    def get_customer_info(self, customer_id: str) -> Dict[str, Any]:
        """Get information about a specific customer.

        Args:
            customer_id: Customer ID

        Returns:
            Customer information dictionary
        """
        try:
            # Get table info from database
            table_info = self.executor.get_table_info(customer_id)

            # Get concept mappings
            concepts = {}
            for concept_id in self.knowledge_graph.concepts.keys():
                mapping = self.knowledge_graph.get_mapping(concept_id, customer_id)
                if mapping:
                    concepts[concept_id] = {
                        "table": mapping.table_name,
                        "column": mapping.column_name,
                        "type": mapping.data_type,
                        "semantic_type": str(mapping.semantic_type),
                        "transformation": mapping.transformation
                    }

            # Get row count from the primary table (first table in table_info)
            primary_table = list(table_info.keys())[0] if table_info else None

            row_count = 0
            if primary_table:
                row_count = self.executor.count_rows(customer_id, primary_table)

            return {
                "customer_id": customer_id,
                "tables": table_info,
                "concepts": concepts,
                "total_rows": row_count,
                "available": True
            }

        except Exception as e:
            return {
                "customer_id": customer_id,
                "error": str(e),
                "available": False
            }

    def get_feedback_insights(self) -> Dict[str, Any]:
        """Get insights from user feedback.

        Returns:
            Feedback analysis and recommendations
        """
        return self.feedback_loop.get_improvement_recommendations()

    def check_schema_drift(
        self,
        customer_ids: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """Check for schema drift in customer databases.

        Args:
            customer_ids: Optional list of customers to check (all if None)

        Returns:
            Dictionary of drift information
        """
        if customer_ids:
            drifts = {}
            for customer_id in customer_ids:
                drift_list = self.drift_detector.detect_drift(customer_id)
                if drift_list:
                    drifts[customer_id] = [d.to_dict() for d in drift_list]
            return drifts
        else:
            # Check all customers
            all_drifts = self.drift_detector.check_all_customers()
            return {
                customer_id: [d.to_dict() for d in drifts]
                for customer_id, drifts in all_drifts.items()
            }

    def get_system_health(self) -> Dict[str, Any]:
        """Get overall system health report.

        Returns:
            Comprehensive health report including feedback and drift
        """
        # Get feedback insights
        feedback_insights = self.get_feedback_insights()

        # Get drift summary
        drift_summary = self.drift_detector.get_drift_summary()

        # Get query statistics
        query_stats = self.get_statistics()

        # Determine overall health
        health_score = 100
        issues = []

        # Check query success rate
        if query_stats.get("success_rate", 0) < 80:
            health_score -= 20
            issues.append("Query success rate below 80%")

        # Check for critical drifts
        if drift_summary.get("critical_drifts"):
            health_score -= 30
            issues.append(f"{len(drift_summary['critical_drifts'])} critical schema drifts detected")

        # Check feedback health
        if feedback_insights.get("overall_health") == "needs_improvement":
            health_score -= 15
            issues.append("User feedback indicates issues")

        health_status = (
            "excellent" if health_score >= 90
            else "good" if health_score >= 70
            else "fair" if health_score >= 50
            else "poor"
        )

        return {
            "health_status": health_status,
            "health_score": health_score,
            "issues": issues,
            "query_statistics": query_stats,
            "feedback_insights": feedback_insights,
            "drift_summary": drift_summary,
            "timestamp": datetime.now(timezone.utc).isoformat()
        }

    # Private methods for query history management (formerly QueryHistory class)

    def _add_to_history(
        self,
        query_text: str,
        semantic_plan: Optional[SemanticQueryPlan],
        result: Optional[HarmonizedResult],
        execution_time_ms: float,
        error: Optional[str] = None
    ):
        """Add a query to history.

        Args:
            query_text: Original natural language query
            semantic_plan: Parsed semantic query plan
            result: Query execution result
            execution_time_ms: Total execution time in milliseconds
            error: Error message if query failed
        """
        self.query_history.append({
            "timestamp": datetime.now(timezone.utc),
            "query_text": query_text,
            "semantic_plan": semantic_plan,
            "result": result,
            "execution_time_ms": execution_time_ms,
            "error": error,
            "success": error is None
        })

    def _get_query_statistics(self) -> Dict[str, Any]:
        """Get query execution statistics.

        Returns:
            Statistics dictionary
        """
        if not self.query_history:
            return {
                "total_queries": 0,
                "successful_queries": 0,
                "failed_queries": 0,
                "success_rate": 0.0,
                "average_execution_time_ms": 0.0
            }

        successful = [q for q in self.query_history if q["success"]]
        total_time = sum(q["execution_time_ms"] for q in self.query_history)

        return {
            "total_queries": len(self.query_history),
            "successful_queries": len(successful),
            "failed_queries": len(self.query_history) - len(successful),
            "success_rate": len(successful) / len(self.query_history) * 100,
            "average_execution_time_ms": total_time / len(self.query_history)
        }
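

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the orchestrator API).
# Assumes the knowledge graph and customer databases referenced by Config()
# already exist on disk; use_llm=False keeps the run in mock mode so no
# Anthropic API key is required. The example query text is hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    orchestrator = ChatOrchestrator(use_llm=False)

    # List the customers the orchestrator can see
    print("Available customers:", orchestrator.list_available_customers())

    # Explain a query without executing it
    explanation = orchestrator.explain_query("Show me active contracts over $100k")
    print("Explanation:", explanation.get("explanation"))

    # Execute the query end-to-end with debug output enabled
    response = orchestrator.process_query(
        "Show me active contracts over $100k",
        debug=True
    )
    if response["success"]:
        print(f"Rows returned: {response['result'].total_count}")
        print(f"Execution time: {response['execution_time_ms']:.2f} ms")
    else:
        print("Query failed:", response["error"])

    # Aggregate statistics and overall system health
    print("Statistics:", orchestrator.get_statistics())
    print("Health:", orchestrator.get_system_health()["health_status"])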