| """Chat orchestrator for coordinating all schema translation components.""" | |
| import logging | |
| import time | |
| from datetime import datetime, timezone | |
| from typing import Dict, List, Optional, Any | |
| from schema_translator.agents import QueryUnderstandingAgent, SchemaAnalyzerAgent | |
| from schema_translator.config import Config | |
| from schema_translator.database_executor import DatabaseExecutor | |
| from schema_translator.knowledge_graph import SchemaKnowledgeGraph | |
| from schema_translator.models import ( | |
| HarmonizedResult, | |
| QueryFeedback, | |
| SemanticQueryPlan, | |
| ) | |
| from schema_translator.query_compiler import QueryCompiler | |
| from schema_translator.result_harmonizer import ResultHarmonizer | |
| from schema_translator.feedback_loop import FeedbackLoop | |
| from schema_translator.schema_drift_detector import SchemaDriftDetector | |
# Configure logging for the module.
# NOTE(review): basicConfig at import time configures the ROOT logger for the
# whole process; confirm this is intended rather than leaving configuration
# to the embedding application.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class ChatOrchestrator:
    """Orchestrates all components for natural language query processing.

    Wires together the knowledge graph, query compiler, result harmonizer,
    optional LLM agents, feedback loop, and schema drift detector, and
    exposes `process_query` as the end-to-end entry point.
    """

    def __init__(
        self,
        config: Optional[Config] = None,
        knowledge_graph: Optional[SchemaKnowledgeGraph] = None,
        use_llm: bool = True
    ):
        """Initialize the chat orchestrator.

        Args:
            config: Configuration object (creates new if None)
            knowledge_graph: Knowledge graph (loads from config if None)
            use_llm: Whether to use LLM agents (False for mock mode)
        """
        self.config = config or Config()
        self.use_llm = use_llm

        # Load knowledge graph from the configured path unless one was
        # injected directly (useful for tests / pre-loaded graphs).
        if knowledge_graph is None:
            logger.info("Loading knowledge graph...")
            self.knowledge_graph = SchemaKnowledgeGraph()
            self.knowledge_graph.load(self.config.knowledge_graph_path)
        else:
            self.knowledge_graph = knowledge_graph

        # Initialize components; compiler and harmonizer both depend on the
        # knowledge graph resolved above.
        logger.info("Initializing components...")
        self.executor = DatabaseExecutor()
        self.compiler = QueryCompiler(self.knowledge_graph)
        self.result_harmonizer = ResultHarmonizer(self.knowledge_graph, self.executor)

        # Initialize agents (if using LLM). In mock mode both agents stay
        # None and _parse_query falls back to a canned plan.
        if self.use_llm:
            logger.info("Initializing LLM agents...")
            self.query_agent = QueryUnderstandingAgent(
                self.config.anthropic_api_key,
                self.knowledge_graph,
                self.config
            )
            self.schema_agent = SchemaAnalyzerAgent(
                self.config.anthropic_api_key,
                self.knowledge_graph,
                self.config
            )
        else:
            logger.info("Running in mock mode (no LLM)")
            self.query_agent = None
            self.schema_agent = None

        # Initialize query history: per-instance, in-memory only.
        self.query_history: List[Dict[str, Any]] = []

        # Initialize feedback loop (collects user quality signals).
        self.feedback_loop = FeedbackLoop()

        # Initialize drift detector; given the executor and the graph,
        # presumably checks live schemas against recorded mappings — see
        # SchemaDriftDetector for the actual contract.
        self.drift_detector = SchemaDriftDetector(
            self.executor,
            self.knowledge_graph
        )

        logger.info("ChatOrchestrator initialized successfully")
    def process_query(
        self,
        query_text: str,
        customer_ids: Optional[List[str]] = None,
        debug: bool = False
    ) -> Dict[str, Any]:
        """Process a natural language query end-to-end.

        Pipeline: validate -> parse to semantic plan -> resolve target
        customers -> execute across customers -> record history -> respond.
        Failures are caught at this boundary and returned as a structured
        error response rather than raised.

        Args:
            query_text: Natural language query
            customer_ids: Optional list of customer IDs to query (all if None)
            debug: Whether to include debug information

        Returns:
            Dictionary with results and metadata
        """
        start_time = time.time()
        logger.info(f"Processing query: '{query_text}'")
        try:
            # Step 1: Validate query (non-empty, minimum length).
            if not self._validate_query(query_text):
                raise ValueError("Invalid query: Query text is empty or too short")

            # Step 2: Parse query to semantic plan (LLM agent or mock plan).
            logger.info("Parsing query to semantic plan...")
            semantic_plan = self._parse_query(query_text)
            if debug:
                logger.info(f"Semantic plan: {semantic_plan}")

            # Step 3: Determine which customers to query.
            # Priority: explicit customer_ids parameter > target_customers from query > all customers
            target_customers = customer_ids
            if target_customers is None and semantic_plan.target_customers:
                target_customers = semantic_plan.target_customers
                logger.info(f"Extracted target customers from query: {target_customers}")

            # Step 4: Execute query across customers (None means all).
            logger.info(f"Executing query across {len(target_customers) if target_customers else 'all'} customers...")
            result = self.result_harmonizer.execute_across_customers(
                semantic_plan,
                customer_ids=target_customers
            )

            # Step 5: Calculate total execution time.
            total_time_ms = (time.time() - start_time) * 1000
            logger.info(
                f"Query completed: {result.total_count} rows, "
                f"{result.success_rate:.1f}% success rate, "
                f"{total_time_ms:.2f}ms"
            )

            # Step 6: Add to history (success record: error=None).
            self._add_to_history(
                query_text=query_text,
                semantic_plan=semantic_plan,
                result=result,
                execution_time_ms=total_time_ms,
                error=None
            )

            # Step 7: Build response; the semantic plan is only echoed back
            # when the caller asked for debug output.
            response = {
                "success": True,
                "query_text": query_text,
                "semantic_plan": semantic_plan if debug else None,
                "result": result,
                "execution_time_ms": total_time_ms,
                "error": None
            }
            if debug:
                response["debug"] = self._build_debug_info(
                    semantic_plan,
                    result,
                    customer_ids
                )
            return response

        except Exception as e:
            # Boundary handler: failed queries are logged and recorded in
            # history, then reported via the response dict, not re-raised.
            error_msg = str(e)
            total_time_ms = (time.time() - start_time) * 1000
            logger.error(f"Query failed: {error_msg}", exc_info=True)

            # Add failed query to history.
            self._add_to_history(
                query_text=query_text,
                semantic_plan=None,
                result=None,
                execution_time_ms=total_time_ms,
                error=error_msg
            )

            return {
                "success": False,
                "query_text": query_text,
                "semantic_plan": None,
                "result": None,
                "execution_time_ms": total_time_ms,
                "error": error_msg
            }
| def _validate_query(self, query_text: str) -> bool: | |
| """Validate a query before processing. | |
| Args: | |
| query_text: Query text to validate | |
| Returns: | |
| True if valid, False otherwise | |
| """ | |
| if not query_text or not query_text.strip(): | |
| return False | |
| if len(query_text.strip()) < 3: | |
| return False | |
| return True | |
| def _parse_query(self, query_text: str) -> SemanticQueryPlan: | |
| """Parse natural language query to semantic plan. | |
| Args: | |
| query_text: Natural language query | |
| Returns: | |
| SemanticQueryPlan | |
| """ | |
| if self.use_llm and self.query_agent: | |
| # Use LLM agent | |
| return self.query_agent.understand_query(query_text) | |
| else: | |
| # Mock mode: create a simple plan | |
| from schema_translator.models import QueryIntent | |
| return SemanticQueryPlan( | |
| intent=QueryIntent.FIND_CONTRACTS, | |
| projections=["contract_identifier", "contract_status", "contract_value"], | |
| filters=[], | |
| aggregations=[], | |
| limit=10 | |
| ) | |
| def _build_debug_info( | |
| self, | |
| semantic_plan: SemanticQueryPlan, | |
| result: HarmonizedResult, | |
| customer_ids: Optional[List[str]] | |
| ) -> Dict[str, Any]: | |
| """Build debug information for a query. | |
| Args: | |
| semantic_plan: Semantic query plan | |
| result: Execution result | |
| customer_ids: Customer IDs queried | |
| Returns: | |
| Debug information dictionary | |
| """ | |
| debug_info = { | |
| "semantic_plan": { | |
| "intent": semantic_plan.intent.value if hasattr(semantic_plan.intent, 'value') else str(semantic_plan.intent), | |
| "projections": semantic_plan.projections, | |
| "filters": [ | |
| { | |
| "concept": f.concept, | |
| "operator": f.operator.value if hasattr(f.operator, 'value') else str(f.operator), | |
| "value": f.value | |
| } | |
| for f in (semantic_plan.filters or []) | |
| ], | |
| "aggregations": [ | |
| { | |
| "function": a.function, | |
| "concept": a.concept | |
| } | |
| for a in (semantic_plan.aggregations or []) | |
| ], | |
| "limit": semantic_plan.limit | |
| }, | |
| "customers": { | |
| "queried": result.customers_queried, | |
| "succeeded": result.customers_succeeded, | |
| "failed": result.customers_failed | |
| }, | |
| "sql_queries": {} | |
| } | |
| # Add SQL for each customer | |
| target_customers = customer_ids if customer_ids else result.customers_queried | |
| for customer_id in target_customers[:3]: # Limit to 3 for brevity | |
| try: | |
| sql = self.compiler.compile_for_customer(semantic_plan, customer_id) | |
| debug_info["sql_queries"][customer_id] = sql | |
| except Exception as e: | |
| debug_info["sql_queries"][customer_id] = f"Error: {e}" | |
| return debug_info | |
| def get_query_history(self, n: int = 10) -> List[Dict[str, Any]]: | |
| """Get recent query history. | |
| Args: | |
| n: Number of recent queries to return | |
| Returns: | |
| List of query records | |
| """ | |
| return self.query_history[-n:] if self.query_history else [] | |
| def get_failed_queries(self) -> List[Dict[str, Any]]: | |
| """Get all failed queries from history. | |
| Returns: | |
| List of failed query records | |
| """ | |
| return [q for q in self.query_history if not q["success"]] | |
| def get_statistics(self) -> Dict[str, Any]: | |
| """Get query execution statistics. | |
| Returns: | |
| Statistics dictionary | |
| """ | |
| stats = self._get_query_statistics() | |
| stats["knowledge_graph"] = self.knowledge_graph.get_stats() | |
| return stats | |
| def submit_feedback( | |
| self, | |
| query_text: str, | |
| semantic_plan: SemanticQueryPlan, | |
| feedback_type: str, | |
| feedback_text: Optional[str] = None | |
| ) -> QueryFeedback: | |
| """Submit feedback on a query result. | |
| Args: | |
| query_text: Original query text | |
| semantic_plan: Semantic query plan used | |
| feedback_type: Type of feedback (incorrect, missing, good) | |
| feedback_text: Optional feedback comment | |
| Returns: | |
| QueryFeedback object | |
| """ | |
| # Submit to feedback loop | |
| feedback = self.feedback_loop.submit_feedback( | |
| query_text=query_text, | |
| semantic_plan=semantic_plan, | |
| feedback_type=feedback_type, | |
| feedback_text=feedback_text | |
| ) | |
| logger.info(f"Feedback received: {feedback_type} for query '{query_text}'") | |
| return feedback | |
| def explain_query(self, query_text: str) -> Dict[str, Any]: | |
| """Explain how a query will be processed without executing it. | |
| Args: | |
| query_text: Natural language query | |
| Returns: | |
| Explanation dictionary | |
| """ | |
| try: | |
| # Parse to semantic plan | |
| semantic_plan = self._parse_query(query_text) | |
| # Get human-readable explanation | |
| if self.use_llm and self.query_agent: | |
| explanation = self.query_agent.explain_query_plan(semantic_plan) | |
| else: | |
| explanation = f"Will find contracts with projections: {semantic_plan.projections}" | |
| # Get sample SQL for a few customers | |
| sample_sql = {} | |
| for customer_id in ["customer_a", "customer_b", "customer_c"]: | |
| try: | |
| sql = self.compiler.compile_for_customer(semantic_plan, customer_id) | |
| sample_sql[customer_id] = sql | |
| except Exception as e: | |
| sample_sql[customer_id] = f"Error: {e}" | |
| return { | |
| "success": True, | |
| "query_text": query_text, | |
| "explanation": explanation, | |
| "semantic_plan": semantic_plan, | |
| "sample_sql": sample_sql | |
| } | |
| except Exception as e: | |
| return { | |
| "success": False, | |
| "query_text": query_text, | |
| "error": str(e) | |
| } | |
| def list_available_customers(self) -> List[str]: | |
| """Get list of available customer IDs. | |
| Returns: | |
| List of customer IDs | |
| """ | |
| db_dir = self.config.database_dir | |
| db_files = list(db_dir.glob("customer_*.db")) | |
| return sorted([f.stem for f in db_files]) | |
| def get_customer_info(self, customer_id: str) -> Dict[str, Any]: | |
| """Get information about a specific customer. | |
| Args: | |
| customer_id: Customer ID | |
| Returns: | |
| Customer information dictionary | |
| """ | |
| try: | |
| # Get table info from database | |
| table_info = self.executor.get_table_info(customer_id) | |
| # Get concept mappings | |
| concepts = {} | |
| for concept_id in self.knowledge_graph.concepts.keys(): | |
| mapping = self.knowledge_graph.get_mapping(concept_id, customer_id) | |
| if mapping: | |
| concepts[concept_id] = { | |
| "table": mapping.table_name, | |
| "column": mapping.column_name, | |
| "type": mapping.data_type, | |
| "semantic_type": str(mapping.semantic_type), | |
| "transformation": mapping.transformation | |
| } | |
| # Get row count | |
| # Get the primary table name from table_info | |
| primary_table = None | |
| if table_info: | |
| primary_table = list(table_info.keys())[0] if table_info else None | |
| row_count = 0 | |
| if primary_table: | |
| row_count = self.executor.count_rows(customer_id, primary_table) | |
| return { | |
| "customer_id": customer_id, | |
| "tables": table_info, | |
| "concepts": concepts, | |
| "total_rows": row_count, | |
| "available": True | |
| } | |
| except Exception as e: | |
| return { | |
| "customer_id": customer_id, | |
| "error": str(e), | |
| "available": False | |
| } | |
| def get_feedback_insights(self) -> Dict[str, Any]: | |
| """Get insights from user feedback. | |
| Returns: | |
| Feedback analysis and recommendations | |
| """ | |
| return self.feedback_loop.get_improvement_recommendations() | |
| def check_schema_drift( | |
| self, | |
| customer_ids: Optional[List[str]] = None | |
| ) -> Dict[str, Any]: | |
| """Check for schema drift in customer databases. | |
| Args: | |
| customer_ids: Optional list of customers to check (all if None) | |
| Returns: | |
| Dictionary of drift information | |
| """ | |
| if customer_ids: | |
| drifts = {} | |
| for customer_id in customer_ids: | |
| drift_list = self.drift_detector.detect_drift(customer_id) | |
| if drift_list: | |
| drifts[customer_id] = [d.to_dict() for d in drift_list] | |
| return drifts | |
| else: | |
| # Check all customers | |
| all_drifts = self.drift_detector.check_all_customers() | |
| return { | |
| customer_id: [d.to_dict() for d in drifts] | |
| for customer_id, drifts in all_drifts.items() | |
| } | |
| def get_system_health(self) -> Dict[str, Any]: | |
| """Get overall system health report. | |
| Returns: | |
| Comprehensive health report including feedback and drift | |
| """ | |
| # Get feedback insights | |
| feedback_insights = self.get_feedback_insights() | |
| # Get drift summary | |
| drift_summary = self.drift_detector.get_drift_summary() | |
| # Get query statistics | |
| query_stats = self.get_statistics() | |
| # Determine overall health | |
| health_score = 100 | |
| issues = [] | |
| # Check query success rate | |
| if query_stats.get("success_rate", 0) < 80: | |
| health_score -= 20 | |
| issues.append("Query success rate below 80%") | |
| # Check for critical drifts | |
| if drift_summary.get("critical_drifts"): | |
| health_score -= 30 | |
| issues.append(f"{len(drift_summary['critical_drifts'])} critical schema drifts detected") | |
| # Check feedback health | |
| if feedback_insights.get("overall_health") == "needs_improvement": | |
| health_score -= 15 | |
| issues.append("User feedback indicates issues") | |
| health_status = "excellent" if health_score >= 90 else \ | |
| "good" if health_score >= 70 else \ | |
| "fair" if health_score >= 50 else "poor" | |
| return { | |
| "health_status": health_status, | |
| "health_score": health_score, | |
| "issues": issues, | |
| "query_statistics": query_stats, | |
| "feedback_insights": feedback_insights, | |
| "drift_summary": drift_summary, | |
| "timestamp": datetime.now(timezone.utc).isoformat() | |
| } | |
| # Private methods for query history management (formerly QueryHistory class) | |
| def _add_to_history( | |
| self, | |
| query_text: str, | |
| semantic_plan: Optional[SemanticQueryPlan], | |
| result: Optional[HarmonizedResult], | |
| execution_time_ms: float, | |
| error: Optional[str] = None | |
| ): | |
| """Add a query to history. | |
| Args: | |
| query_text: Original natural language query | |
| semantic_plan: Parsed semantic query plan | |
| result: Query execution result | |
| execution_time_ms: Total execution time | |
| error: Error message if query failed | |
| """ | |
| self.query_history.append({ | |
| "timestamp": datetime.now(timezone.utc), | |
| "query_text": query_text, | |
| "semantic_plan": semantic_plan, | |
| "result": result, | |
| "execution_time_ms": execution_time_ms, | |
| "error": error, | |
| "success": error is None | |
| }) | |
| def _get_query_statistics(self) -> Dict[str, Any]: | |
| """Get query execution statistics. | |
| Returns: | |
| Statistics dictionary | |
| """ | |
| if not self.query_history: | |
| return { | |
| "total_queries": 0, | |
| "successful_queries": 0, | |
| "failed_queries": 0, | |
| "success_rate": 0.0, | |
| "average_execution_time_ms": 0.0 | |
| } | |
| successful = [q for q in self.query_history if q["success"]] | |
| total_time = sum(q["execution_time_ms"] for q in self.query_history) | |
| return { | |
| "total_queries": len(self.query_history), | |
| "successful_queries": len(successful), | |
| "failed_queries": len(self.query_history) - len(successful), | |
| "success_rate": len(successful) / len(self.query_history) * 100, | |
| "average_execution_time_ms": total_time / len(self.query_history) | |
| } | |