Spaces:
Paused
Paused
| """ | |
| Chainlit UI for Schema Translator | |
| This module provides a chat interface for querying customer databases | |
| using natural language queries that are automatically translated to SQL. | |
| """ | |
| import chainlit as cl | |
| from typing import Optional, List, Dict, Any | |
| import logging | |
| from datetime import datetime | |
| from schema_translator.orchestrator import ChatOrchestrator | |
| from schema_translator.config import Config | |
# Configure logging once at import time: INFO level, with timestamp, logger
# name and level prefixed to every record.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Global orchestrator instance. Starts as None; start() assigns a
# ChatOrchestrator to it when a chat session is initialized.
orchestrator: Optional[ChatOrchestrator] = None
def _escape_table_cell(text: str) -> str:
    """Make *text* safe to embed in a single markdown table cell."""
    # A raw pipe would start a new column and a line break would end the row.
    return (
        text.replace("|", "\\|")
        .replace("\r\n", " ")
        .replace("\n", " ")
        .replace("\r", " ")
    )


def format_result_table(result, limit: int = 10) -> str:
    """Format harmonized result as a markdown table.

    Columns are taken from the first row's ``data`` mapping. Known contract
    fields come first in a fixed order with friendlier header names; any
    other keys follow in their original order.

    Args:
        result: HarmonizedResult object; only ``result.results`` (a sequence
            of rows, each exposing a ``data`` dict) is read here.
        limit: Maximum number of rows to display (default 10). Negative
            values are treated as 0.

    Returns:
        Markdown formatted table, or ``*No results found*`` when
        ``result.results`` is empty.
    """
    rows = result.results
    if not rows:
        return "*No results found*"

    # Preferred field order and nice names
    field_order = [
        'contract_identifier',
        'customer_name',
        'contract_name',
        'industry_sector',
        'contract_value',
        'contract_status',
        'contract_expiration',
    ]
    nice_names = {
        'contract_identifier': 'Contract ID',
        'contract_name': 'Contract Name',
        'customer_name': 'Company',
        'contract_value': 'Value',
        'contract_status': 'Status',
        'contract_expiration': 'Expiration',
        'contract_start': 'Start Date',
        'industry_sector': 'Industry',
    }

    # Order columns: preferred order first, then any remaining.
    all_columns = list(rows[0].data.keys())
    columns = [field for field in field_order if field in all_columns]
    preferred = set(columns)
    columns.extend(col for col in all_columns if col not in preferred)

    display_cols = [_escape_table_cell(str(nice_names.get(col, col))) for col in columns]
    lines = [
        "| " + " | ".join(display_cols) + " |",
        "|" + "|".join("---" for _ in columns) + "|",
    ]

    # Rows (limit to specified number or total results, whichever is smaller)
    max_rows = max(0, min(limit, len(rows)))
    for row in rows[:max_rows]:
        values = []
        for col in columns:
            val = row.data.get(col, "")
            if val is None or val == "None":
                # NOTE(review): this placeholder looks like a mis-decoded
                # dash character; left unchanged - confirm the intended glyph.
                cell = "β"
            elif isinstance(val, (int, float)) and col == 'contract_value':
                # Format currency values
                cell = f"${val:,.0f}"
            else:
                cell = str(val)
            values.append(_escape_table_cell(cell))
        lines.append("| " + " | ".join(values) + " |")

    if len(rows) > max_rows:
        lines.append(f"\n*Showing {max_rows} of {len(rows)} rows. Use `/limit <number>` to adjust.*")
    return "\n".join(lines)
def format_statistics(stats: Dict[str, Any]) -> str:
    """Render per-query execution statistics as a markdown bullet list.

    Args:
        stats: Statistics dictionary with the keys ``success_rate``,
            ``total_rows``, ``customers_queried``, ``customers_succeeded``,
            ``execution_time_ms`` and ``customers_failed``.

    Returns:
        Formatted statistics; the failed-customers bullet is appended only
        when ``customers_failed`` is non-empty.
    """
    # NOTE(review): the heading contains a character that looks like a
    # mis-decoded emoji; it is reproduced unchanged.
    summary = "\n".join((
        "### π Execution Statistics",
        "- **Success Rate:** {:.1f}%".format(stats['success_rate']),
        "- **Total Rows:** {}".format(stats['total_rows']),
        "- **Customers Queried:** {}".format(len(stats['customers_queried'])),
        "- **Customers Succeeded:** {}".format(len(stats['customers_succeeded'])),
        "- **Execution Time:** {:.2f}ms".format(stats['execution_time_ms']),
    ))

    failed_customers = stats['customers_failed']
    if not failed_customers:
        return summary
    return summary + "\n" + "- **Customers Failed:** " + ', '.join(failed_customers)
def format_debug_info(debug: Dict[str, Any]) -> str:
    """Render the debug payload of a query as markdown.

    Args:
        debug: Debug information dictionary. ``semantic_plan`` (with
            ``intent``, ``projections``, ``filters``, ``aggregations``) and
            ``sql_queries`` are read unconditionally; ``actual_columns`` is
            optional.

    Returns:
        Formatted debug info: the semantic plan, the returned columns when
        provided, and the SQL of the first entry in ``sql_queries``.
    """
    plan = debug['semantic_plan']
    projections = plan['projections']
    if projections:
        proj_list = ', '.join(f'`{p}`' for p in projections)
    else:
        proj_list = "*(none - SELECT ALL)*"

    # NOTE(review): the heading contains a character that looks like a
    # mis-decoded emoji; it is reproduced unchanged.
    out = [
        "### π Debug Information",
        "",
        "**Semantic Plan:**",
        f"- Intent: `{plan['intent']}`",
        f"- Projections ({len(projections)}): {proj_list}",
    ]

    plan_filters = plan['filters']
    if plan_filters:
        out.append(f"- Filters: {len(plan_filters)}")
        out.extend(
            f" - `{flt['concept']}` {flt['operator']} `{flt['value']}`"
            for flt in plan_filters
        )

    aggregations = plan['aggregations']
    if aggregations:
        out.append(f"- Aggregations: {', '.join(aggregations)}")

    # Columns actually present in the result rows, when the caller added them.
    if 'actual_columns' in debug:
        returned = debug['actual_columns']
        out.append(f"\n**Actual Columns Returned ({len(returned)}):**")
        out.append(', '.join(f'`{c}`' for c in returned))

    # The heading is emitted even when there is no SQL to show.
    out.append("\n**Sample SQL (first customer):**")
    queries = debug['sql_queries']
    if queries:
        first_customer = next(iter(queries))
        out.append(f"```sql\n{queries[first_customer]}\n```")

    return "\n".join(out)
# Registered as the chat-start hook; without the decorator Chainlit never
# calls this function.
@cl.on_chat_start
async def start():
    """Initialize the chat session.

    Creates a ChatOrchestrator (LLM-backed when ``Config().anthropic_api_key``
    is truthy), stores it and the per-session defaults in ``cl.user_session``,
    and sends the welcome message. Any failure is logged and reported to the
    user instead of being raised.
    """
    global orchestrator

    try:
        config = Config()
        orchestrator = ChatOrchestrator(use_llm=bool(config.anthropic_api_key))

        # Per-session state read back by the message handler.
        cl.user_session.set("orchestrator", orchestrator)
        cl.user_session.set("debug_mode", False)
        cl.user_session.set("selected_customers", [])  # Empty = all customers
        cl.user_session.set("result_limit", 10)  # Default rows to display

        mode = "LLM mode" if orchestrator.use_llm else "Mock mode"

        # NOTE(review): some characters in the user-facing strings below look
        # like mis-decoded emoji; they are reproduced unchanged.
        welcome_msg = f"""# π― Schema Translator Chat
Welcome! I'm running in **{mode}**.
I can help you query customer databases using natural language. Just ask me questions like:
- "Show me all contracts"
- "Find active contracts with value over 10000"
- "Count contracts by status"
- "List all customers"
### Available Commands:
- `/customers` - List available customers
- `/limit <number>` - Set max rows to display (default: 10)
- `/debug on/off` - Toggle debug mode
- `/stats` - Show query statistics
- `/help` - Show this help message
Try asking me a question!
"""
        await cl.Message(content=welcome_msg).send()

        customers = orchestrator.list_available_customers()
        logger.info("Chat started with %d customers available", len(customers))
    except Exception as e:
        # Top-level UI boundary: log with traceback and tell the user.
        logger.exception("Failed to initialize orchestrator: %s", e)
        await cl.Message(
            content=f"β **Error:** Failed to initialize. Please check configuration.\n\n`{e}`"
        ).send()
def _build_success_content(response: Dict[str, Any], result_limit: int, debug_mode: bool) -> str:
    """Assemble the markdown body shown for a successful query response."""
    result = response['result']
    content_parts = []

    # Show available fields
    if result.results:
        fields = list(result.results[0].data.keys())
        field_names = {
            'contract_identifier': 'Contract ID',
            'contract_name': 'Contract Name',
            'customer_name': 'Company',
            'contract_value': 'Value',
            'contract_status': 'Status',
            'contract_expiration': 'Expiration Date',
            'contract_start': 'Start Date',
            'industry_sector': 'Industry',
        }
        field_list = ', '.join(f"`{field_names.get(f, f)}`" for f in fields)
        content_parts.append(f"**Showing {len(fields)} fields:** {field_list}\n")

    # Results table
    content_parts.append("### β Query Results")
    content_parts.append(format_result_table(result, limit=result_limit))

    # Statistics
    stats = {
        'success_rate': result.success_rate,
        'total_rows': result.total_count,
        'customers_queried': result.customers_queried,
        'customers_succeeded': result.customers_succeeded,
        'customers_failed': result.customers_failed,
        'execution_time_ms': response['execution_time_ms'],
    }
    content_parts.append("\n" + format_statistics(stats))

    # Debug info if enabled
    if debug_mode and 'debug' in response:
        # Copy so the orchestrator's debug dict is not mutated.
        debug_with_columns = response['debug'].copy()
        debug_with_columns['actual_columns'] = (
            list(result.results[0].data.keys()) if result.results else []
        )
        content_parts.append("\n" + format_debug_info(debug_with_columns))

    return "\n\n".join(content_parts)


# Registered as the message hook; without the decorator Chainlit never
# calls this function.
@cl.on_message
async def main(message: cl.Message):
    """Handle incoming chat messages.

    Text starting with ``/`` is routed to ``handle_command``. Anything else
    is sent to the session's orchestrator, and the formatted result (or an
    error message) is posted back to the chat.

    Args:
        message: The incoming Chainlit message; only ``message.content``
            is read.
    """
    # Session-scoped state; kept local so one session never overwrites the
    # module-level orchestrator reference.
    orchestrator = cl.user_session.get("orchestrator")
    debug_mode = cl.user_session.get("debug_mode", False)
    selected_customers = cl.user_session.get("selected_customers", [])
    result_limit = cl.user_session.get("result_limit", 10)

    # NOTE(review): some characters in the user-facing strings below look
    # like mis-decoded emoji; they are reproduced unchanged.
    if not orchestrator:
        await cl.Message(content="β Orchestrator not initialized. Please refresh.").send()
        return

    query_text = message.content.strip()

    # Handle commands
    if query_text.startswith("/"):
        await handle_command(query_text, orchestrator, debug_mode)
        return

    # Validate query (an empty string is also shorter than 3 characters)
    if len(query_text) < 3:
        await cl.Message(content="β οΈ Please enter a valid query (at least 3 characters).").send()
        return

    # Show processing message
    processing_msg = cl.Message(content="π€ Processing your query...")
    await processing_msg.send()

    try:
        # An empty selection means "all customers", passed as None.
        customer_ids = selected_customers if selected_customers else None
        response = orchestrator.process_query(
            query_text,
            customer_ids=customer_ids,
            debug=debug_mode
        )

        if response['success']:
            content = _build_success_content(response, result_limit, debug_mode)

            # Remove processing message and send result
            await processing_msg.remove()
            await cl.Message(content=content).send()

            # Add action buttons
            actions = [
                cl.Action(name="explain", payload={"action": "explain"}, label="π Explain Query"),
                cl.Action(name="feedback_good", payload={"action": "good"}, label="π Good Result"),
                cl.Action(name="feedback_bad", payload={"action": "incorrect"}, label="π Incorrect Result"),
            ]
            if not debug_mode:
                actions.insert(0, cl.Action(name="debug", payload={"action": "debug"}, label="π Show Debug Info"))
            await cl.Message(content="", actions=actions).send()
        else:
            # Format error response
            error_msg = f"""### β Query Failed
**Error:** {response['error']}
**Suggestions:**
- Make sure your query is clear and specific
- Try using simpler language
- Use `/help` to see example queries
- Enable debug mode with `/debug on` for more details
"""
            await processing_msg.remove()
            await cl.Message(content=error_msg).send()
    except Exception as e:
        # Top-level UI boundary: log with traceback and tell the user.
        logger.exception("Error processing query: %s", e)
        await processing_msg.remove()
        await cl.Message(
            content=f"β **Unexpected Error:** {e}\n\nPlease try again or contact support."
        ).send()
# Bound to the action named "explain" that the message handler attaches to
# results; without the decorator the button click is never routed here.
@cl.action_callback("explain")
async def on_explain(action: cl.Action):
    """Handle explain action.

    The original query is not available to this callback, so the user is
    pointed at the ``/explain`` command instead.

    Args:
        action: The triggering Chainlit action (unused).
    """
    await cl.Message(
        content="To explain a query, type: `/explain <your query>`"
    ).send()
# Bound to the action named "feedback_good"; without the decorator the
# button click is never routed here.
@cl.action_callback("feedback_good")
async def on_feedback_good(action: cl.Action):
    """Handle good feedback by acknowledging it in the chat.

    Args:
        action: The triggering Chainlit action (unused).
    """
    # NOTE(review): the leading character looks like a mis-decoded emoji;
    # it is reproduced unchanged.
    await cl.Message(content="β Thank you for the feedback!").send()
# Bound to the action named "feedback_bad"; without the decorator the
# button click is never routed here.
@cl.action_callback("feedback_bad")
async def on_feedback_bad(action: cl.Action):
    """Handle bad feedback by acknowledging it and asking for details.

    Args:
        action: The triggering Chainlit action (unused).
    """
    # NOTE(review): the leading character looks like a mis-decoded emoji;
    # it is reproduced unchanged.
    await cl.Message(
        content="π Thank you for the feedback! We'll work on improving this. Please provide more details about what was incorrect."
    ).send()
# Bound to the action named "debug"; without the decorator the button click
# is never routed here.
@cl.action_callback("debug")
async def on_debug(action: cl.Action):
    """Handle debug action by turning debug mode on for the session.

    Args:
        action: The triggering Chainlit action (unused).
    """
    # Set the flag the message handler reads, so the confirmation below is
    # actually true.
    cl.user_session.set("debug_mode", True)
    # NOTE(review): the leading character looks like a mis-decoded emoji;
    # it is reproduced unchanged.
    await cl.Message(
        content="π Debug mode enabled for this session. Use `/debug on` to enable for all queries."
    ).send()
async def handle_command(command: str, orchestrator: ChatOrchestrator, debug_mode: bool):
    """Handle special commands.

    Supported commands: ``/help``, ``/customers``, ``/debug``, ``/stats``,
    ``/explain``, ``/select`` and ``/limit``. Anything else gets an
    "unknown command" reply. Every branch answers with a chat message.

    Args:
        command: Command string, e.g. ``"/limit 20"``.
        orchestrator: ChatOrchestrator instance.
        debug_mode: Current debug mode state.
    """
    parts = command.split(maxsplit=1)
    cmd = parts[0].lower() if parts else ""
    # Whitespace split never yields an empty remainder, so an empty string
    # here means "no argument given".
    argument = parts[1].strip() if len(parts) > 1 else ""

    # NOTE(review): some characters in the user-facing strings below look
    # like mis-decoded emoji; they are reproduced unchanged.
    if cmd == "/help":
        help_msg = """### π Help
**Example Queries:**
- "Show me all contracts"
- "Find active contracts"
- "Count contracts by status"
- "List contracts with value over 10000"
- "Show contracts expiring soon"
**Commands:**
- `/customers` - List available customers
- `/select <customer_id>` - Query specific customer(s)
- `/select all` - Query all customers (default)
- `/limit <number>` - Set max rows to display (default: 10)
- `/debug on/off` - Toggle debug mode
- `/stats` - Show query statistics
- `/explain <query>` - Explain how a query will be processed
- `/help` - Show this help message
**Tips:**
- Use natural language - I'll understand it!
- Be specific about what you want to see
- Use filters like "active", "over 1000", "expiring soon"
- Enable debug mode to see SQL queries
"""
        await cl.Message(content=help_msg).send()

    elif cmd == "/customers":
        customers = orchestrator.list_available_customers()
        # Only customers whose info reports them as available are listed.
        customer_info = []
        for customer_id in customers:
            info = orchestrator.get_customer_info(customer_id)
            if info['available']:
                customer_info.append(
                    f"- **{customer_id}**: {info['total_rows']} rows, "
                    f"{len(info['concepts'])} concepts mapped"
                )
        content = f"### π₯ Available Customers ({len(customers)})\n\n" + "\n".join(customer_info)
        await cl.Message(content=content).send()

    elif cmd == "/debug":
        if argument:
            setting = argument.lower()
            if setting == "on":
                cl.user_session.set("debug_mode", True)
                await cl.Message(content="π Debug mode **enabled**. You'll see SQL queries and semantic plans.").send()
            elif setting == "off":
                cl.user_session.set("debug_mode", False)
                await cl.Message(content="π Debug mode **disabled**.").send()
            else:
                await cl.Message(content="β οΈ Use `/debug on` or `/debug off`").send()
        else:
            status = "enabled" if debug_mode else "disabled"
            await cl.Message(content=f"π Debug mode is currently **{status}**.").send()

    elif cmd == "/stats":
        stats = orchestrator.get_statistics()
        graph = stats['knowledge_graph']
        history_stats = f"""### π Query Statistics
**Query History:**
- Total Queries: {stats['total_queries']}
- Successful: {stats['successful_queries']} ({stats['success_rate']:.1f}%)
- Failed: {stats['failed_queries']}
- Average Execution Time: {stats['average_execution_time_ms']:.2f}ms
**Knowledge Graph:**
- Total Concepts: {graph['total_concepts']}
- Total Customers: {graph['total_customers']}
- Total Mappings: {graph['total_mappings']}
- Total Transformations: {graph['total_transformations']}
"""
        await cl.Message(content=history_stats).send()

    elif cmd == "/explain":
        if argument:
            query = argument
            try:
                explanation = orchestrator.explain_query(query)
                sample_sql_by_customer = explanation['sample_sql']
                # Show the SQL of the first entry, or N/A when there is none.
                if sample_sql_by_customer:
                    sample_sql = next(iter(sample_sql_by_customer.values()))
                else:
                    sample_sql = 'N/A'
                content = f"""### π Query Explanation
**Query:** "{query}"
**Explanation:** {explanation['explanation']}
**Sample SQL:**
```sql
{sample_sql}
```
"""
                await cl.Message(content=content).send()
            except Exception as e:
                # UI boundary: keep the chat alive, but leave a traceback.
                logger.exception("Error explaining query: %s", e)
                await cl.Message(content=f"β Error explaining query: {e}").send()
        else:
            await cl.Message(content="β οΈ Usage: `/explain <your query>`").send()

    elif cmd == "/select":
        if argument:
            if argument.lower() == "all":
                cl.user_session.set("selected_customers", [])
                await cl.Message(content="β Now querying **all customers**.").send()
            else:
                # Parse customer IDs (comma-separated)
                requested = [c.strip() for c in argument.split(",")]
                available = orchestrator.list_available_customers()

                # Match exactly first, then case-insensitively, and always
                # store the ID spelled the way the orchestrator reports it.
                by_lower = {str(c).lower(): c for c in available}
                resolved = []
                invalid = []
                for customer in requested:
                    if customer in available:
                        resolved.append(customer)
                    elif customer.lower() in by_lower:
                        resolved.append(by_lower[customer.lower()])
                    else:
                        invalid.append(customer)

                if invalid:
                    await cl.Message(
                        content=f"β οΈ Invalid customer IDs: {', '.join(invalid)}\n\n"
                        f"Available: {', '.join(available)}"
                    ).send()
                else:
                    # Drop duplicates while keeping the requested order.
                    customer_ids = list(dict.fromkeys(resolved))
                    cl.user_session.set("selected_customers", customer_ids)
                    await cl.Message(
                        content=f"β Now querying: **{', '.join(customer_ids)}**"
                    ).send()
        else:
            selected = cl.user_session.get("selected_customers", [])
            if selected:
                await cl.Message(content=f"Currently querying: **{', '.join(selected)}**").send()
            else:
                await cl.Message(content="Currently querying: **all customers**").send()

    elif cmd == "/limit":
        if argument:
            try:
                limit = int(argument)
            except ValueError:
                await cl.Message(content="β οΈ Invalid number. Use `/limit <number>`").send()
            else:
                if limit < 1:
                    await cl.Message(content="β οΈ Limit must be at least 1.").send()
                elif limit > 1000:
                    await cl.Message(content="β οΈ Limit cannot exceed 1000 rows.").send()
                else:
                    cl.user_session.set("result_limit", limit)
                    await cl.Message(content=f"β Result limit set to **{limit} rows**.").send()
        else:
            current_limit = cl.user_session.get("result_limit", 10)
            await cl.Message(content=f"Current result limit: **{current_limit} rows**").send()

    else:
        await cl.Message(content=f"β Unknown command: `{cmd}`. Type `/help` for available commands.").send()
if __name__ == "__main__":
    # This is for development only: running the module directly is a no-op.
    # Use `chainlit run app.py` to start the server.
    pass