sanzgiri's picture
v2.0: Dynamic result formatting and natural language customer selection
4bb196e
raw
history blame
24.5 kB
"""
Chainlit UI for Schema Translator
This module provides a chat interface for querying customer databases
using natural language queries that are automatically translated to SQL.
"""
import chainlit as cl
from typing import Optional, List, Dict, Any
import logging
from datetime import datetime
from schema_translator.orchestrator import ChatOrchestrator
from schema_translator.config import Config
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
# Global orchestrator instance
orchestrator: Optional[ChatOrchestrator] = None
def format_result_table(result, limit: int = 10) -> str:
"""Format harmonized result as a markdown table.
Args:
result: HarmonizedResult object
limit: Maximum number of rows to display (default 10)
Returns:
Markdown formatted table
"""
if not result.results:
return "*No results found*"
# Get column names from first row
if not result.results:
return "*No data*"
first_row = result.results[0]
all_columns = list(first_row.data.keys())
# Add source database to the data (from customer_id in HarmonizedRow)
# Convert customer_a -> A, customer_b -> B, etc. for display
for row in result.results:
if 'source_db' not in row.data:
# Extract letter from customer_id (customer_a -> A)
customer_letter = row.customer_id.replace('customer_', '').upper()
row.data['source_db'] = customer_letter
# Add source_db to all_columns if not already there
if 'source_db' not in all_columns:
all_columns.insert(0, 'source_db')
# Filter to only columns that have non-null values in at least one row
columns_with_values = set()
for col in all_columns:
if any(row.data.get(col) is not None for row in result.results):
columns_with_values.add(col)
# Check if multiple customers are being queried
multiple_customers = len(result.customers_queried) > 1
# Always include source_db if querying multiple customers
if multiple_customers and 'source_db' not in columns_with_values:
columns_with_values.add('source_db')
# Check if this is an aggregation/count query (no contract_identifier or all null)
is_aggregation = 'contract_identifier' not in columns_with_values
if is_aggregation:
# For aggregations, show source_db first, then all other non-null columns
columns = []
if 'source_db' in columns_with_values:
columns.append('source_db')
for col in columns_with_values:
if col != 'source_db':
columns.append(col)
# Nice names for aggregation fields
nice_names = {
'source_db': 'Customer',
'count_contract_identifier': 'Count',
'sum_contract_value': 'Total Value',
'avg_contract_value': 'Avg Value',
'max_contract_value': 'Max Value',
'min_contract_value': 'Min Value',
'count': 'Count',
'sum': 'Sum',
'average': 'Average',
'max': 'Max',
'min': 'Min'
}
else:
# For regular queries, use preferred field order (only for columns with values)
field_order = [
'source_db', # Show which customer (A, B, C, etc.)
'contract_identifier', # Contract ID
'contract_value',
'contract_status',
'contract_expiration',
'contract_start'
]
nice_names = {
'source_db': 'Customer', # Which database (A, B, C, D, E, F)
'contract_identifier': 'Contract ID',
'contract_value': 'Value',
'contract_status': 'Status',
'contract_expiration': 'Expiration',
'contract_start': 'Start Date'
}
# Order columns: preferred order first (only include if they have values)
columns = []
for field in field_order:
if field in columns_with_values:
columns.append(field)
# Add any remaining columns not in preferred order (that have values)
for col in columns_with_values:
if col not in columns:
columns.append(col)
# Build markdown table with better formatting
lines = []
# Header with nicer column names
display_cols = []
for col in columns:
# Get nice name or convert snake_case to Title Case
if col in nice_names:
display_cols.append(nice_names[col])
else:
# Convert snake_case to Title Case
display_cols.append(col.replace('_', ' ').title())
header = "| " + " | ".join(display_cols) + " |"
separator = "|" + "|".join(["---" for _ in columns]) + "|"
lines.append(header)
lines.append(separator)
# Rows (limit to specified number or total results, whichever is smaller)
max_rows = min(limit, len(result.results))
for row in result.results[:max_rows]:
values = []
for col in columns:
val = row.data.get(col, "")
# Format values for better display
if val is None or val == "None":
val = "β€”"
elif isinstance(val, (int, float)):
# Format numbers based on column type
if 'value' in col.lower() or 'sum' in col.lower():
# Format currency values
val = f"${val:,.0f}"
elif 'avg' in col.lower() or 'average' in col.lower():
# Format averages with 2 decimals
val = f"${val:,.2f}" if 'value' in col.lower() else f"{val:,.2f}"
else:
# Format counts and other integers
val = f"{val:,}"
else:
val = str(val)
values.append(val)
line = "| " + " | ".join(values) + " |"
lines.append(line)
if len(result.results) > max_rows:
lines.append(f"\n*Showing {max_rows} of {len(result.results)} rows. Use `/limit <number>` to adjust.*")
return "\n".join(lines)
def format_statistics(stats: Dict[str, Any]) -> str:
"""Format statistics as markdown.
Args:
stats: Statistics dictionary
Returns:
Formatted statistics
"""
lines = [
"### πŸ“Š Execution Statistics",
f"- **Success Rate:** {stats['success_rate']:.1f}%",
f"- **Total Rows:** {stats['total_rows']}",
f"- **Customers Queried:** {len(stats['customers_queried'])}",
f"- **Customers Succeeded:** {len(stats['customers_succeeded'])}",
f"- **Execution Time:** {stats['execution_time_ms']:.2f}ms"
]
if stats['customers_failed']:
lines.append(f"- **Customers Failed:** {', '.join(stats['customers_failed'])}")
return "\n".join(lines)
def format_debug_info(debug: Dict[str, Any]) -> str:
"""Format debug information as markdown.
Args:
debug: Debug information dictionary
Returns:
Formatted debug info
"""
projections = debug['semantic_plan']['projections']
proj_list = ', '.join(f'`{p}`' for p in projections) if projections else "*(none - SELECT ALL)*"
lines = [
"### πŸ” Debug Information",
"",
"**Semantic Plan:**",
f"- Intent: `{debug['semantic_plan']['intent']}`",
f"- Projections ({len(projections)}): {proj_list}",
]
if debug['semantic_plan']['filters']:
lines.append(f"- Filters: {len(debug['semantic_plan']['filters'])}")
for f in debug['semantic_plan']['filters']:
lines.append(f" - `{f['concept']}` {f['operator']} `{f['value']}`")
if debug['semantic_plan']['aggregations']:
lines.append(f"- Aggregations: {', '.join(debug['semantic_plan']['aggregations'])}")
# Show actual columns returned
if 'actual_columns' in debug:
actual_cols = ', '.join(f'`{c}`' for c in debug['actual_columns'])
lines.append(f"\n**Actual Columns Returned ({len(debug['actual_columns'])}):**")
lines.append(actual_cols)
# Show sample SQL for first customer
lines.append("\n**Sample SQL (first customer):**")
if debug['sql_queries']:
first_customer = list(debug['sql_queries'].keys())[0]
sql = debug['sql_queries'][first_customer]
lines.append(f"```sql\n{sql}\n```")
return "\n".join(lines)
@cl.on_chat_start
async def start():
"""Initialize the chat session."""
global orchestrator
# Initialize orchestrator
try:
config = Config()
orchestrator = ChatOrchestrator(use_llm=bool(config.anthropic_api_key))
# Store in session
cl.user_session.set("orchestrator", orchestrator)
cl.user_session.set("debug_mode", False)
cl.user_session.set("selected_customers", []) # Empty = all customers
cl.user_session.set("result_limit", 10) # Default rows to display
mode = "LLM mode" if orchestrator.use_llm else "Mock mode"
# Send welcome message
welcome_msg = f"""# 🎯 Schema Translator Chat
Welcome! I'm running in **{mode}**.
I can help you query customer databases using natural language. Just ask me questions like:
- "Show me all contracts"
- "Find active contracts with value over 10000"
- "Show contracts from customer A" or "from customer_a"
- "List contracts in customer B and C"
- "Count contracts by status"
**Note:** Results show Customer column (A, B, C, etc.) indicating which database each contract is from.
### Available Commands:
- `/customers` - List available databases (customer_a through customer_f)
- `/select <customer_id>` - Query specific database (e.g., `/select customer_a`)
- `/limit <number>` - Set max rows to display (default: 10)
- `/debug on/off` - Toggle debug mode
- `/stats` - Show query statistics
- `/help` - Show this help message
Try asking me a question!
"""
await cl.Message(content=welcome_msg).send()
# Get available customers
customers = orchestrator.list_available_customers()
logger.info(f"Chat started with {len(customers)} customers available")
except Exception as e:
logger.error(f"Failed to initialize orchestrator: {e}", exc_info=True)
await cl.Message(
content=f"❌ **Error:** Failed to initialize. Please check configuration.\n\n`{str(e)}`"
).send()
@cl.on_message
async def main(message: cl.Message):
"""Handle incoming chat messages."""
global orchestrator
orchestrator = cl.user_session.get("orchestrator")
debug_mode = cl.user_session.get("debug_mode", False)
selected_customers = cl.user_session.get("selected_customers", [])
result_limit = cl.user_session.get("result_limit", 10)
if not orchestrator:
await cl.Message(content="❌ Orchestrator not initialized. Please refresh.").send()
return
query_text = message.content.strip()
# Handle commands
if query_text.startswith("/"):
await handle_command(query_text, orchestrator, debug_mode)
return
# Validate query
if not query_text or len(query_text) < 3:
await cl.Message(content="⚠️ Please enter a valid query (at least 3 characters).").send()
return
# Show processing message
processing_msg = cl.Message(content="πŸ€” Processing your query...")
await processing_msg.send()
try:
# Execute query
customer_ids = selected_customers if selected_customers else None
response = orchestrator.process_query(
query_text,
customer_ids=customer_ids,
debug=debug_mode
)
if response['success']:
# Format successful response
result = response['result']
# Build response message
content_parts = []
# Show available fields
if result.results:
fields = list(result.results[0].data.keys())
# Filter out None values to get actual fields returned
actual_fields = [f for f in fields if any(
row.data.get(f) is not None for row in result.results
)]
field_names = {
'source_db': 'Customer',
'contract_identifier': 'Contract ID',
'contract_value': 'Value',
'contract_status': 'Status',
'contract_expiration': 'Expiration Date',
'contract_start': 'Start Date',
'count_contract_identifier': 'Count',
'sum_contract_value': 'Total Value',
'avg_contract_value': 'Avg Value',
'max_contract_value': 'Max Value',
'min_contract_value': 'Min Value'
}
field_display = []
for f in actual_fields:
nice_name = field_names.get(f, f.replace('_', ' ').title())
field_display.append(f"`{nice_name}`")
field_list = ', '.join(field_display)
content_parts.append(f"**Showing {len(actual_fields)} fields:** {field_list}\n")
# Results table
content_parts.append("### βœ… Query Results")
content_parts.append(format_result_table(result, limit=result_limit))
# Statistics
stats = {
'success_rate': result.success_rate,
'total_rows': result.total_count,
'customers_queried': result.customers_queried,
'customers_succeeded': result.customers_succeeded,
'customers_failed': result.customers_failed,
'execution_time_ms': response['execution_time_ms']
}
content_parts.append("\n" + format_statistics(stats))
# Debug info if enabled
if debug_mode and 'debug' in response:
# Add actual columns returned
actual_columns = list(result.results[0].data.keys()) if result.results else []
debug_with_columns = response['debug'].copy()
debug_with_columns['actual_columns'] = actual_columns
content_parts.append("\n" + format_debug_info(debug_with_columns))
# Remove processing message and send result
await processing_msg.remove()
await cl.Message(content="\n\n".join(content_parts)).send()
# Add action buttons
actions = [
cl.Action(name="explain", payload={"action": "explain"}, label="πŸ“– Explain Query"),
cl.Action(name="feedback_good", payload={"action": "good"}, label="πŸ‘ Good Result"),
cl.Action(name="feedback_bad", payload={"action": "incorrect"}, label="πŸ‘Ž Incorrect Result"),
]
if not debug_mode:
actions.insert(0, cl.Action(name="debug", payload={"action": "debug"}, label="πŸ” Show Debug Info"))
await cl.Message(content="", actions=actions).send()
else:
# Format error response
error_msg = f"""### ❌ Query Failed
**Error:** {response['error']}
**Suggestions:**
- Make sure your query is clear and specific
- Try using simpler language
- Use `/help` to see example queries
- Enable debug mode with `/debug on` for more details
"""
await processing_msg.remove()
await cl.Message(content=error_msg).send()
except Exception as e:
logger.error(f"Error processing query: {e}", exc_info=True)
await processing_msg.remove()
await cl.Message(
content=f"❌ **Unexpected Error:** {str(e)}\n\nPlease try again or contact support."
).send()
@cl.action_callback("explain")
async def on_explain(action: cl.Action):
"""Handle explain action."""
orchestrator = cl.user_session.get("orchestrator")
# Get the query from the previous message
# For now, send a message asking to use /explain command
await cl.Message(
content="To explain a query, type: `/explain <your query>`"
).send()
@cl.action_callback("feedback_good")
async def on_feedback_good(action: cl.Action):
"""Handle good feedback."""
await cl.Message(content="βœ… Thank you for the feedback!").send()
@cl.action_callback("feedback_bad")
async def on_feedback_bad(action: cl.Action):
"""Handle bad feedback."""
await cl.Message(
content="πŸ“ Thank you for the feedback! We'll work on improving this. Please provide more details about what was incorrect."
).send()
@cl.action_callback("debug")
async def on_debug(action: cl.Action):
"""Handle debug action."""
await cl.Message(
content="πŸ” Debug mode enabled for this session. Use `/debug on` to enable for all queries."
).send()
async def handle_command(command: str, orchestrator: ChatOrchestrator, debug_mode: bool):
"""Handle special commands.
Args:
command: Command string
orchestrator: ChatOrchestrator instance
debug_mode: Current debug mode state
"""
parts = command.split(maxsplit=1)
cmd = parts[0].lower()
if cmd == "/help":
help_msg = """### πŸ“š Help
**Example Queries:**
- "Show me all contracts"
- "Find active contracts"
- "Show contracts from customer A" or "from customer_a"
- "Query customer B and customer C databases"
- "Count contracts by status"
- "List contracts with value over 10000"
- "Show active contracts in customer_a expiring in 2026"
**Important Notes:**
- You can query specific customers by saying "customer A", "customer_a", "from customer B", etc.
- Without specifying, queries run across all available databases
- **Customer** column in results shows which database (A, B, C, etc.) each contract came from
**Commands:**
- `/customers` - List available databases (customer_a through customer_f)
- `/select <customer_id>` - Query specific database (e.g., `/select customer_a`)
- `/select all` - Query all databases (default)
- `/limit <number>` - Set max rows to display (default: 10)
- `/debug on/off` - Toggle debug mode
- `/stats` - Show query statistics
- `/explain <query>` - Explain how a query will be processed
- `/help` - Show this help message
**Tips:**
- Use natural language - I'll understand it!
- Be specific about what you want to see
- Use filters like "active", "over 10000", "expiring in 90 days"
- Enable debug mode to see SQL queries and semantic plans
"""
await cl.Message(content=help_msg).send()
elif cmd == "/customers":
customers = orchestrator.list_available_customers()
# Get info for each customer
customer_info = []
for customer_id in customers:
info = orchestrator.get_customer_info(customer_id)
if info['available']:
customer_info.append(
f"- **{customer_id}**: {info['total_rows']} rows, "
f"{len(info['concepts'])} concepts mapped"
)
content = f"### πŸ‘₯ Available Customers ({len(customers)})\n\n" + "\n".join(customer_info)
await cl.Message(content=content).send()
elif cmd == "/debug":
if len(parts) > 1:
setting = parts[1].lower()
if setting == "on":
cl.user_session.set("debug_mode", True)
await cl.Message(content="πŸ” Debug mode **enabled**. You'll see SQL queries and semantic plans.").send()
elif setting == "off":
cl.user_session.set("debug_mode", False)
await cl.Message(content="πŸ” Debug mode **disabled**.").send()
else:
await cl.Message(content="⚠️ Use `/debug on` or `/debug off`").send()
else:
status = "enabled" if debug_mode else "disabled"
await cl.Message(content=f"πŸ” Debug mode is currently **{status}**.").send()
elif cmd == "/stats":
stats = orchestrator.get_statistics()
history_stats = f"""### πŸ“Š Query Statistics
**Query History:**
- Total Queries: {stats['total_queries']}
- Successful: {stats['successful_queries']} ({stats['success_rate']:.1f}%)
- Failed: {stats['failed_queries']}
- Average Execution Time: {stats['average_execution_time_ms']:.2f}ms
**Knowledge Graph:**
- Total Concepts: {stats['knowledge_graph']['total_concepts']}
- Total Customers: {stats['knowledge_graph']['total_customers']}
- Total Mappings: {stats['knowledge_graph']['total_mappings']}
- Total Transformations: {stats['knowledge_graph']['total_transformations']}
"""
await cl.Message(content=history_stats).send()
elif cmd == "/explain":
if len(parts) > 1:
query = parts[1]
try:
explanation = orchestrator.explain_query(query)
content = f"""### πŸ“– Query Explanation
**Query:** "{query}"
**Explanation:** {explanation['explanation']}
**Sample SQL:**
```sql
{list(explanation['sample_sql'].values())[0] if explanation['sample_sql'] else 'N/A'}
```
"""
await cl.Message(content=content).send()
except Exception as e:
await cl.Message(content=f"❌ Error explaining query: {str(e)}").send()
else:
await cl.Message(content="⚠️ Usage: `/explain <your query>`").send()
elif cmd == "/select":
if len(parts) > 1:
selection = parts[1].lower()
if selection == "all":
cl.user_session.set("selected_customers", [])
await cl.Message(content="βœ… Now querying **all customers**.").send()
else:
# Parse customer IDs (comma-separated)
customer_ids = [c.strip() for c in selection.split(",")]
available = orchestrator.list_available_customers()
# Validate
invalid = [c for c in customer_ids if c not in available]
if invalid:
await cl.Message(
content=f"⚠️ Invalid customer IDs: {', '.join(invalid)}\n\n"
f"Available: {', '.join(available)}"
).send()
else:
cl.user_session.set("selected_customers", customer_ids)
await cl.Message(
content=f"βœ… Now querying: **{', '.join(customer_ids)}**"
).send()
else:
selected = cl.user_session.get("selected_customers", [])
if selected:
await cl.Message(content=f"Currently querying: **{', '.join(selected)}**").send()
else:
await cl.Message(content="Currently querying: **all customers**").send()
elif cmd == "/limit":
if len(parts) > 1:
try:
limit = int(parts[1])
if limit < 1:
await cl.Message(content="⚠️ Limit must be at least 1.").send()
elif limit > 1000:
await cl.Message(content="⚠️ Limit cannot exceed 1000 rows.").send()
else:
cl.user_session.set("result_limit", limit)
await cl.Message(content=f"βœ… Result limit set to **{limit} rows**.").send()
except ValueError:
await cl.Message(content="⚠️ Invalid number. Use `/limit <number>`").send()
else:
current_limit = cl.user_session.get("result_limit", 10)
await cl.Message(content=f"Current result limit: **{current_limit} rows**").send()
else:
await cl.Message(content=f"❓ Unknown command: `{cmd}`. Type `/help` for available commands.").send()
if __name__ == "__main__":
# This is for development only
# Use `chainlit run app.py` to start the server
pass