Spaces:
Paused
v2.0: Dynamic result formatting and natural language customer selection
Browse filesMajor features:
- Added target_customers field to SemanticQueryPlan for natural language database selection
- Updated query understanding agent to extract customer references (customer_a, customer A, etc.)
- Modified orchestrator to use extracted target_customers from query plan
- Fixed result harmonizer to include aggregation columns (count, sum, avg, etc.)
- Implemented dynamic table formatting in app.py that filters null columns
- Customer column now displays as A, B, C instead of customer_a, customer_b
- Removed customer_name and industry_sector concepts from knowledge graph
- Updated all tests to reflect schema changes (156 tests passing)
Key improvements:
- Queries like 'show contracts from customer A' now work naturally
- Count/sum/aggregation queries display correctly with only relevant columns
- No more empty columns with dashes in result tables
- Smart Customer column display (shown when querying multiple databases)
- All query types supported: regular, count, sum, filtered, multi-customer
- app.py +138 -48
- initialize_kg.py +6 -130
- knowledge_graph.json +0 -201
- schema_translator/agents/query_understanding.py +31 -11
- schema_translator/mock_data.py +30 -77
- schema_translator/models.py +6 -2
- schema_translator/orchestrator.py +13 -6
- schema_translator/result_harmonizer.py +9 -0
- tests/test_agents.py +2 -5
- tests/test_knowledge_graph.py +4 -3
- tests/test_models.py +0 -1
- tests/test_query_execution.py +13 -13
- tests/test_result_harmonization.py +22 -22
|
@@ -45,44 +45,101 @@ def format_result_table(result, limit: int = 10) -> str:
|
|
| 45 |
first_row = result.results[0]
|
| 46 |
all_columns = list(first_row.data.keys())
|
| 47 |
|
| 48 |
-
#
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
'
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
'customer_name': 'Company', # Changed from Customer to Company
|
| 63 |
-
'contract_value': 'Value',
|
| 64 |
-
'contract_status': 'Status',
|
| 65 |
-
'contract_expiration': 'Expiration',
|
| 66 |
-
'contract_start': 'Start Date',
|
| 67 |
-
'industry_sector': 'Industry'
|
| 68 |
-
}
|
| 69 |
-
|
| 70 |
-
# Order columns: preferred order first, then any remaining
|
| 71 |
-
columns = []
|
| 72 |
-
for field in field_order:
|
| 73 |
-
if field in all_columns:
|
| 74 |
-
columns.append(field)
|
| 75 |
-
|
| 76 |
-
# Add any remaining columns not in preferred order
|
| 77 |
for col in all_columns:
|
| 78 |
-
if col not in
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
# Build markdown table with better formatting
|
| 82 |
lines = []
|
| 83 |
|
| 84 |
# Header with nicer column names
|
| 85 |
-
display_cols = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
header = "| " + " | ".join(display_cols) + " |"
|
| 87 |
separator = "|" + "|".join(["---" for _ in columns]) + "|"
|
| 88 |
lines.append(header)
|
|
@@ -97,9 +154,17 @@ def format_result_table(result, limit: int = 10) -> str:
|
|
| 97 |
# Format values for better display
|
| 98 |
if val is None or val == "None":
|
| 99 |
val = "—"
|
| 100 |
-
elif isinstance(val, (int, float))
|
| 101 |
-
# Format
|
| 102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
else:
|
| 104 |
val = str(val)
|
| 105 |
values.append(val)
|
|
@@ -206,11 +271,15 @@ Welcome! I'm running in **{mode}**.
|
|
| 206 |
I can help you query customer databases using natural language. Just ask me questions like:
|
| 207 |
- "Show me all contracts"
|
| 208 |
- "Find active contracts with value over 10000"
|
|
|
|
|
|
|
| 209 |
- "Count contracts by status"
|
| 210 |
-
|
|
|
|
| 211 |
|
| 212 |
### Available Commands:
|
| 213 |
-
- `/customers` - List available
|
|
|
|
| 214 |
- `/limit <number>` - Set max rows to display (default: 10)
|
| 215 |
- `/debug on/off` - Toggle debug mode
|
| 216 |
- `/stats` - Show query statistics
|
|
@@ -282,18 +351,32 @@ async def main(message: cl.Message):
|
|
| 282 |
# Show available fields
|
| 283 |
if result.results:
|
| 284 |
fields = list(result.results[0].data.keys())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
field_names = {
|
|
|
|
| 286 |
'contract_identifier': 'Contract ID',
|
| 287 |
-
'contract_name': 'Contract Name',
|
| 288 |
-
'customer_name': 'Company', # Changed from Customer to Company
|
| 289 |
'contract_value': 'Value',
|
| 290 |
'contract_status': 'Status',
|
| 291 |
'contract_expiration': 'Expiration Date',
|
| 292 |
'contract_start': 'Start Date',
|
| 293 |
-
'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
}
|
| 295 |
-
|
| 296 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
# Results table
|
| 299 |
content_parts.append("### ✅ Query Results")
|
|
@@ -408,14 +491,21 @@ async def handle_command(command: str, orchestrator: ChatOrchestrator, debug_mod
|
|
| 408 |
**Example Queries:**
|
| 409 |
- "Show me all contracts"
|
| 410 |
- "Find active contracts"
|
|
|
|
|
|
|
| 411 |
- "Count contracts by status"
|
| 412 |
- "List contracts with value over 10000"
|
| 413 |
-
- "Show contracts expiring
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 414 |
|
| 415 |
**Commands:**
|
| 416 |
-
- `/customers` - List available
|
| 417 |
-
- `/select <customer_id>` - Query specific
|
| 418 |
-
- `/select all` - Query all
|
| 419 |
- `/limit <number>` - Set max rows to display (default: 10)
|
| 420 |
- `/debug on/off` - Toggle debug mode
|
| 421 |
- `/stats` - Show query statistics
|
|
@@ -425,8 +515,8 @@ async def handle_command(command: str, orchestrator: ChatOrchestrator, debug_mod
|
|
| 425 |
**Tips:**
|
| 426 |
- Use natural language - I'll understand it!
|
| 427 |
- Be specific about what you want to see
|
| 428 |
-
- Use filters like "active", "over
|
| 429 |
-
- Enable debug mode to see SQL queries
|
| 430 |
"""
|
| 431 |
await cl.Message(content=help_msg).send()
|
| 432 |
|
|
|
|
| 45 |
first_row = result.results[0]
|
| 46 |
all_columns = list(first_row.data.keys())
|
| 47 |
|
| 48 |
+
# Add source database to the data (from customer_id in HarmonizedRow)
|
| 49 |
+
# Convert customer_a -> A, customer_b -> B, etc. for display
|
| 50 |
+
for row in result.results:
|
| 51 |
+
if 'source_db' not in row.data:
|
| 52 |
+
# Extract letter from customer_id (customer_a -> A)
|
| 53 |
+
customer_letter = row.customer_id.replace('customer_', '').upper()
|
| 54 |
+
row.data['source_db'] = customer_letter
|
| 55 |
+
|
| 56 |
+
# Add source_db to all_columns if not already there
|
| 57 |
+
if 'source_db' not in all_columns:
|
| 58 |
+
all_columns.insert(0, 'source_db')
|
| 59 |
+
|
| 60 |
+
# Filter to only columns that have non-null values in at least one row
|
| 61 |
+
columns_with_values = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
for col in all_columns:
|
| 63 |
+
if any(row.data.get(col) is not None for row in result.results):
|
| 64 |
+
columns_with_values.add(col)
|
| 65 |
+
|
| 66 |
+
# Check if multiple customers are being queried
|
| 67 |
+
multiple_customers = len(result.customers_queried) > 1
|
| 68 |
+
|
| 69 |
+
# Always include source_db if querying multiple customers
|
| 70 |
+
if multiple_customers and 'source_db' not in columns_with_values:
|
| 71 |
+
columns_with_values.add('source_db')
|
| 72 |
+
|
| 73 |
+
# Check if this is an aggregation/count query (no contract_identifier or all null)
|
| 74 |
+
is_aggregation = 'contract_identifier' not in columns_with_values
|
| 75 |
+
|
| 76 |
+
if is_aggregation:
|
| 77 |
+
# For aggregations, show source_db first, then all other non-null columns
|
| 78 |
+
columns = []
|
| 79 |
+
if 'source_db' in columns_with_values:
|
| 80 |
+
columns.append('source_db')
|
| 81 |
+
for col in columns_with_values:
|
| 82 |
+
if col != 'source_db':
|
| 83 |
+
columns.append(col)
|
| 84 |
+
|
| 85 |
+
# Nice names for aggregation fields
|
| 86 |
+
nice_names = {
|
| 87 |
+
'source_db': 'Customer',
|
| 88 |
+
'count_contract_identifier': 'Count',
|
| 89 |
+
'sum_contract_value': 'Total Value',
|
| 90 |
+
'avg_contract_value': 'Avg Value',
|
| 91 |
+
'max_contract_value': 'Max Value',
|
| 92 |
+
'min_contract_value': 'Min Value',
|
| 93 |
+
'count': 'Count',
|
| 94 |
+
'sum': 'Sum',
|
| 95 |
+
'average': 'Average',
|
| 96 |
+
'max': 'Max',
|
| 97 |
+
'min': 'Min'
|
| 98 |
+
}
|
| 99 |
+
else:
|
| 100 |
+
# For regular queries, use preferred field order (only for columns with values)
|
| 101 |
+
field_order = [
|
| 102 |
+
'source_db', # Show which customer (A, B, C, etc.)
|
| 103 |
+
'contract_identifier', # Contract ID
|
| 104 |
+
'contract_value',
|
| 105 |
+
'contract_status',
|
| 106 |
+
'contract_expiration',
|
| 107 |
+
'contract_start'
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
nice_names = {
|
| 111 |
+
'source_db': 'Customer', # Which database (A, B, C, D, E, F)
|
| 112 |
+
'contract_identifier': 'Contract ID',
|
| 113 |
+
'contract_value': 'Value',
|
| 114 |
+
'contract_status': 'Status',
|
| 115 |
+
'contract_expiration': 'Expiration',
|
| 116 |
+
'contract_start': 'Start Date'
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
# Order columns: preferred order first (only include if they have values)
|
| 120 |
+
columns = []
|
| 121 |
+
for field in field_order:
|
| 122 |
+
if field in columns_with_values:
|
| 123 |
+
columns.append(field)
|
| 124 |
+
|
| 125 |
+
# Add any remaining columns not in preferred order (that have values)
|
| 126 |
+
for col in columns_with_values:
|
| 127 |
+
if col not in columns:
|
| 128 |
+
columns.append(col)
|
| 129 |
|
| 130 |
# Build markdown table with better formatting
|
| 131 |
lines = []
|
| 132 |
|
| 133 |
# Header with nicer column names
|
| 134 |
+
display_cols = []
|
| 135 |
+
for col in columns:
|
| 136 |
+
# Get nice name or convert snake_case to Title Case
|
| 137 |
+
if col in nice_names:
|
| 138 |
+
display_cols.append(nice_names[col])
|
| 139 |
+
else:
|
| 140 |
+
# Convert snake_case to Title Case
|
| 141 |
+
display_cols.append(col.replace('_', ' ').title())
|
| 142 |
+
|
| 143 |
header = "| " + " | ".join(display_cols) + " |"
|
| 144 |
separator = "|" + "|".join(["---" for _ in columns]) + "|"
|
| 145 |
lines.append(header)
|
|
|
|
| 154 |
# Format values for better display
|
| 155 |
if val is None or val == "None":
|
| 156 |
val = "—"
|
| 157 |
+
elif isinstance(val, (int, float)):
|
| 158 |
+
# Format numbers based on column type
|
| 159 |
+
if 'value' in col.lower() or 'sum' in col.lower():
|
| 160 |
+
# Format currency values
|
| 161 |
+
val = f"${val:,.0f}"
|
| 162 |
+
elif 'avg' in col.lower() or 'average' in col.lower():
|
| 163 |
+
# Format averages with 2 decimals
|
| 164 |
+
val = f"${val:,.2f}" if 'value' in col.lower() else f"{val:,.2f}"
|
| 165 |
+
else:
|
| 166 |
+
# Format counts and other integers
|
| 167 |
+
val = f"{val:,}"
|
| 168 |
else:
|
| 169 |
val = str(val)
|
| 170 |
values.append(val)
|
|
|
|
| 271 |
I can help you query customer databases using natural language. Just ask me questions like:
|
| 272 |
- "Show me all contracts"
|
| 273 |
- "Find active contracts with value over 10000"
|
| 274 |
+
- "Show contracts from customer A" or "from customer_a"
|
| 275 |
+
- "List contracts in customer B and C"
|
| 276 |
- "Count contracts by status"
|
| 277 |
+
|
| 278 |
+
**Note:** Results show Customer column (A, B, C, etc.) indicating which database each contract is from.
|
| 279 |
|
| 280 |
### Available Commands:
|
| 281 |
+
- `/customers` - List available databases (customer_a through customer_f)
|
| 282 |
+
- `/select <customer_id>` - Query specific database (e.g., `/select customer_a`)
|
| 283 |
- `/limit <number>` - Set max rows to display (default: 10)
|
| 284 |
- `/debug on/off` - Toggle debug mode
|
| 285 |
- `/stats` - Show query statistics
|
|
|
|
| 351 |
# Show available fields
|
| 352 |
if result.results:
|
| 353 |
fields = list(result.results[0].data.keys())
|
| 354 |
+
# Filter out None values to get actual fields returned
|
| 355 |
+
actual_fields = [f for f in fields if any(
|
| 356 |
+
row.data.get(f) is not None for row in result.results
|
| 357 |
+
)]
|
| 358 |
+
|
| 359 |
field_names = {
|
| 360 |
+
'source_db': 'Customer',
|
| 361 |
'contract_identifier': 'Contract ID',
|
|
|
|
|
|
|
| 362 |
'contract_value': 'Value',
|
| 363 |
'contract_status': 'Status',
|
| 364 |
'contract_expiration': 'Expiration Date',
|
| 365 |
'contract_start': 'Start Date',
|
| 366 |
+
'count_contract_identifier': 'Count',
|
| 367 |
+
'sum_contract_value': 'Total Value',
|
| 368 |
+
'avg_contract_value': 'Avg Value',
|
| 369 |
+
'max_contract_value': 'Max Value',
|
| 370 |
+
'min_contract_value': 'Min Value'
|
| 371 |
}
|
| 372 |
+
|
| 373 |
+
field_display = []
|
| 374 |
+
for f in actual_fields:
|
| 375 |
+
nice_name = field_names.get(f, f.replace('_', ' ').title())
|
| 376 |
+
field_display.append(f"`{nice_name}`")
|
| 377 |
+
|
| 378 |
+
field_list = ', '.join(field_display)
|
| 379 |
+
content_parts.append(f"**Showing {len(actual_fields)} fields:** {field_list}\n")
|
| 380 |
|
| 381 |
# Results table
|
| 382 |
content_parts.append("### ✅ Query Results")
|
|
|
|
| 491 |
**Example Queries:**
|
| 492 |
- "Show me all contracts"
|
| 493 |
- "Find active contracts"
|
| 494 |
+
- "Show contracts from customer A" or "from customer_a"
|
| 495 |
+
- "Query customer B and customer C databases"
|
| 496 |
- "Count contracts by status"
|
| 497 |
- "List contracts with value over 10000"
|
| 498 |
+
- "Show active contracts in customer_a expiring in 2026"
|
| 499 |
+
|
| 500 |
+
**Important Notes:**
|
| 501 |
+
- You can query specific customers by saying "customer A", "customer_a", "from customer B", etc.
|
| 502 |
+
- Without specifying, queries run across all available databases
|
| 503 |
+
- **Customer** column in results shows which database (A, B, C, etc.) each contract came from
|
| 504 |
|
| 505 |
**Commands:**
|
| 506 |
+
- `/customers` - List available databases (customer_a through customer_f)
|
| 507 |
+
- `/select <customer_id>` - Query specific database (e.g., `/select customer_a`)
|
| 508 |
+
- `/select all` - Query all databases (default)
|
| 509 |
- `/limit <number>` - Set max rows to display (default: 10)
|
| 510 |
- `/debug on/off` - Toggle debug mode
|
| 511 |
- `/stats` - Show query statistics
|
|
|
|
| 515 |
**Tips:**
|
| 516 |
- Use natural language - I'll understand it!
|
| 517 |
- Be specific about what you want to see
|
| 518 |
+
- Use filters like "active", "over 10000", "expiring in 90 days"
|
| 519 |
+
- Enable debug mode to see SQL queries and semantic plans
|
| 520 |
"""
|
| 521 |
await cl.Message(content=help_msg).send()
|
| 522 |
|
|
@@ -280,137 +280,13 @@ def initialize_knowledge_graph() -> SchemaKnowledgeGraph:
|
|
| 280 |
)
|
| 281 |
|
| 282 |
# ========================================================================
|
| 283 |
-
# 5.
|
| 284 |
# ========================================================================
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
aliases=["industry", "sector", "vertical", "business_sector"]
|
| 291 |
-
)
|
| 292 |
-
|
| 293 |
-
kg.add_customer_mapping(
|
| 294 |
-
concept_id="industry_sector",
|
| 295 |
-
customer_id="customer_a",
|
| 296 |
-
table_name="contracts",
|
| 297 |
-
column_name="industry",
|
| 298 |
-
data_type="TEXT",
|
| 299 |
-
semantic_type=SemanticType.TEXT
|
| 300 |
-
)
|
| 301 |
-
|
| 302 |
-
kg.add_customer_mapping(
|
| 303 |
-
concept_id="industry_sector",
|
| 304 |
-
customer_id="customer_b",
|
| 305 |
-
table_name="contract_headers",
|
| 306 |
-
column_name="sector",
|
| 307 |
-
data_type="TEXT",
|
| 308 |
-
semantic_type=SemanticType.TEXT
|
| 309 |
-
)
|
| 310 |
-
|
| 311 |
-
kg.add_customer_mapping(
|
| 312 |
-
concept_id="industry_sector",
|
| 313 |
-
customer_id="customer_c",
|
| 314 |
-
table_name="contracts",
|
| 315 |
-
column_name="business_sector",
|
| 316 |
-
data_type="TEXT",
|
| 317 |
-
semantic_type=SemanticType.TEXT
|
| 318 |
-
)
|
| 319 |
-
|
| 320 |
-
kg.add_customer_mapping(
|
| 321 |
-
concept_id="industry_sector",
|
| 322 |
-
customer_id="customer_d",
|
| 323 |
-
table_name="contracts",
|
| 324 |
-
column_name="industry",
|
| 325 |
-
data_type="TEXT",
|
| 326 |
-
semantic_type=SemanticType.TEXT
|
| 327 |
-
)
|
| 328 |
-
|
| 329 |
-
kg.add_customer_mapping(
|
| 330 |
-
concept_id="industry_sector",
|
| 331 |
-
customer_id="customer_e",
|
| 332 |
-
table_name="contracts",
|
| 333 |
-
column_name="industry",
|
| 334 |
-
data_type="TEXT",
|
| 335 |
-
semantic_type=SemanticType.TEXT
|
| 336 |
-
)
|
| 337 |
-
|
| 338 |
-
kg.add_customer_mapping(
|
| 339 |
-
concept_id="industry_sector",
|
| 340 |
-
customer_id="customer_f",
|
| 341 |
-
table_name="contracts",
|
| 342 |
-
column_name="sector",
|
| 343 |
-
data_type="TEXT",
|
| 344 |
-
semantic_type=SemanticType.TEXT
|
| 345 |
-
)
|
| 346 |
-
|
| 347 |
-
# ========================================================================
|
| 348 |
-
# 6. CUSTOMER NAME
|
| 349 |
-
# ========================================================================
|
| 350 |
-
print(" Adding concept: customer_name")
|
| 351 |
-
kg.add_concept(
|
| 352 |
-
concept_id="customer_name",
|
| 353 |
-
concept_name="Customer Name",
|
| 354 |
-
description="Name of the customer/client organization",
|
| 355 |
-
aliases=["client", "account", "organization", "client_name"]
|
| 356 |
-
)
|
| 357 |
-
|
| 358 |
-
kg.add_customer_mapping(
|
| 359 |
-
concept_id="customer_name",
|
| 360 |
-
customer_id="customer_a",
|
| 361 |
-
table_name="contracts",
|
| 362 |
-
column_name="customer_name",
|
| 363 |
-
data_type="TEXT",
|
| 364 |
-
semantic_type=SemanticType.TEXT
|
| 365 |
-
)
|
| 366 |
-
|
| 367 |
-
kg.add_customer_mapping(
|
| 368 |
-
concept_id="customer_name",
|
| 369 |
-
customer_id="customer_b",
|
| 370 |
-
table_name="contract_headers",
|
| 371 |
-
column_name="client_name",
|
| 372 |
-
data_type="TEXT",
|
| 373 |
-
semantic_type=SemanticType.TEXT
|
| 374 |
-
)
|
| 375 |
-
|
| 376 |
-
kg.add_customer_mapping(
|
| 377 |
-
concept_id="customer_name",
|
| 378 |
-
customer_id="customer_c",
|
| 379 |
-
table_name="contracts",
|
| 380 |
-
column_name="account",
|
| 381 |
-
data_type="TEXT",
|
| 382 |
-
semantic_type=SemanticType.TEXT
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
kg.add_customer_mapping(
|
| 386 |
-
concept_id="customer_name",
|
| 387 |
-
customer_id="customer_d",
|
| 388 |
-
table_name="contracts",
|
| 389 |
-
column_name="customer_org",
|
| 390 |
-
data_type="TEXT",
|
| 391 |
-
semantic_type=SemanticType.TEXT
|
| 392 |
-
)
|
| 393 |
-
|
| 394 |
-
kg.add_customer_mapping(
|
| 395 |
-
concept_id="customer_name",
|
| 396 |
-
customer_id="customer_e",
|
| 397 |
-
table_name="contracts",
|
| 398 |
-
column_name="customer_name",
|
| 399 |
-
data_type="TEXT",
|
| 400 |
-
semantic_type=SemanticType.TEXT
|
| 401 |
-
)
|
| 402 |
-
|
| 403 |
-
kg.add_customer_mapping(
|
| 404 |
-
concept_id="customer_name",
|
| 405 |
-
customer_id="customer_f",
|
| 406 |
-
table_name="contracts",
|
| 407 |
-
column_name="account",
|
| 408 |
-
data_type="TEXT",
|
| 409 |
-
semantic_type=SemanticType.TEXT
|
| 410 |
-
)
|
| 411 |
-
|
| 412 |
-
# ========================================================================
|
| 413 |
-
# 7. CONTRACT START
|
| 414 |
# ========================================================================
|
| 415 |
print(" Adding concept: contract_start")
|
| 416 |
kg.add_concept(
|
|
|
|
| 280 |
)
|
| 281 |
|
| 282 |
# ========================================================================
|
| 283 |
+
# 5. CONTRACT START
|
| 284 |
# ========================================================================
|
| 285 |
+
# NOTE: industry_sector and customer_name concepts were removed.
|
| 286 |
+
# - customer_name: Refers to company names in contracts (e.g., "Global Tech Inc")
|
| 287 |
+
# which conflicts with database selection (customer_a, customer_b, etc.)
|
| 288 |
+
# - industry_sector: Internal attribute that doesn't represent typical
|
| 289 |
+
# business queries users would make.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
# ========================================================================
|
| 291 |
print(" Adding concept: contract_start")
|
| 292 |
kg.add_concept(
|
|
@@ -271,140 +271,6 @@
|
|
| 271 |
}
|
| 272 |
}
|
| 273 |
},
|
| 274 |
-
"industry_sector": {
|
| 275 |
-
"concept_id": "industry_sector",
|
| 276 |
-
"concept_name": "Industry Sector",
|
| 277 |
-
"description": "Business industry or vertical of the customer",
|
| 278 |
-
"aliases": [
|
| 279 |
-
"industry",
|
| 280 |
-
"sector",
|
| 281 |
-
"vertical",
|
| 282 |
-
"business_sector"
|
| 283 |
-
],
|
| 284 |
-
"customer_mappings": {
|
| 285 |
-
"customer_a": {
|
| 286 |
-
"customer_id": "customer_a",
|
| 287 |
-
"table_name": "contracts",
|
| 288 |
-
"column_name": "industry",
|
| 289 |
-
"data_type": "TEXT",
|
| 290 |
-
"semantic_type": "text",
|
| 291 |
-
"transformation": null,
|
| 292 |
-
"join_requirements": []
|
| 293 |
-
},
|
| 294 |
-
"customer_b": {
|
| 295 |
-
"customer_id": "customer_b",
|
| 296 |
-
"table_name": "contract_headers",
|
| 297 |
-
"column_name": "sector",
|
| 298 |
-
"data_type": "TEXT",
|
| 299 |
-
"semantic_type": "text",
|
| 300 |
-
"transformation": null,
|
| 301 |
-
"join_requirements": []
|
| 302 |
-
},
|
| 303 |
-
"customer_c": {
|
| 304 |
-
"customer_id": "customer_c",
|
| 305 |
-
"table_name": "contracts",
|
| 306 |
-
"column_name": "business_sector",
|
| 307 |
-
"data_type": "TEXT",
|
| 308 |
-
"semantic_type": "text",
|
| 309 |
-
"transformation": null,
|
| 310 |
-
"join_requirements": []
|
| 311 |
-
},
|
| 312 |
-
"customer_d": {
|
| 313 |
-
"customer_id": "customer_d",
|
| 314 |
-
"table_name": "contracts",
|
| 315 |
-
"column_name": "industry",
|
| 316 |
-
"data_type": "TEXT",
|
| 317 |
-
"semantic_type": "text",
|
| 318 |
-
"transformation": null,
|
| 319 |
-
"join_requirements": []
|
| 320 |
-
},
|
| 321 |
-
"customer_e": {
|
| 322 |
-
"customer_id": "customer_e",
|
| 323 |
-
"table_name": "contracts",
|
| 324 |
-
"column_name": "industry",
|
| 325 |
-
"data_type": "TEXT",
|
| 326 |
-
"semantic_type": "text",
|
| 327 |
-
"transformation": null,
|
| 328 |
-
"join_requirements": []
|
| 329 |
-
},
|
| 330 |
-
"customer_f": {
|
| 331 |
-
"customer_id": "customer_f",
|
| 332 |
-
"table_name": "contracts",
|
| 333 |
-
"column_name": "sector",
|
| 334 |
-
"data_type": "TEXT",
|
| 335 |
-
"semantic_type": "text",
|
| 336 |
-
"transformation": null,
|
| 337 |
-
"join_requirements": []
|
| 338 |
-
}
|
| 339 |
-
}
|
| 340 |
-
},
|
| 341 |
-
"customer_name": {
|
| 342 |
-
"concept_id": "customer_name",
|
| 343 |
-
"concept_name": "Customer Name",
|
| 344 |
-
"description": "Name of the customer/client organization",
|
| 345 |
-
"aliases": [
|
| 346 |
-
"client",
|
| 347 |
-
"account",
|
| 348 |
-
"organization",
|
| 349 |
-
"client_name"
|
| 350 |
-
],
|
| 351 |
-
"customer_mappings": {
|
| 352 |
-
"customer_a": {
|
| 353 |
-
"customer_id": "customer_a",
|
| 354 |
-
"table_name": "contracts",
|
| 355 |
-
"column_name": "customer_name",
|
| 356 |
-
"data_type": "TEXT",
|
| 357 |
-
"semantic_type": "text",
|
| 358 |
-
"transformation": null,
|
| 359 |
-
"join_requirements": []
|
| 360 |
-
},
|
| 361 |
-
"customer_b": {
|
| 362 |
-
"customer_id": "customer_b",
|
| 363 |
-
"table_name": "contract_headers",
|
| 364 |
-
"column_name": "client_name",
|
| 365 |
-
"data_type": "TEXT",
|
| 366 |
-
"semantic_type": "text",
|
| 367 |
-
"transformation": null,
|
| 368 |
-
"join_requirements": []
|
| 369 |
-
},
|
| 370 |
-
"customer_c": {
|
| 371 |
-
"customer_id": "customer_c",
|
| 372 |
-
"table_name": "contracts",
|
| 373 |
-
"column_name": "account",
|
| 374 |
-
"data_type": "TEXT",
|
| 375 |
-
"semantic_type": "text",
|
| 376 |
-
"transformation": null,
|
| 377 |
-
"join_requirements": []
|
| 378 |
-
},
|
| 379 |
-
"customer_d": {
|
| 380 |
-
"customer_id": "customer_d",
|
| 381 |
-
"table_name": "contracts",
|
| 382 |
-
"column_name": "customer_org",
|
| 383 |
-
"data_type": "TEXT",
|
| 384 |
-
"semantic_type": "text",
|
| 385 |
-
"transformation": null,
|
| 386 |
-
"join_requirements": []
|
| 387 |
-
},
|
| 388 |
-
"customer_e": {
|
| 389 |
-
"customer_id": "customer_e",
|
| 390 |
-
"table_name": "contracts",
|
| 391 |
-
"column_name": "customer_name",
|
| 392 |
-
"data_type": "TEXT",
|
| 393 |
-
"semantic_type": "text",
|
| 394 |
-
"transformation": null,
|
| 395 |
-
"join_requirements": []
|
| 396 |
-
},
|
| 397 |
-
"customer_f": {
|
| 398 |
-
"customer_id": "customer_f",
|
| 399 |
-
"table_name": "contracts",
|
| 400 |
-
"column_name": "account",
|
| 401 |
-
"data_type": "TEXT",
|
| 402 |
-
"semantic_type": "text",
|
| 403 |
-
"transformation": null,
|
| 404 |
-
"join_requirements": []
|
| 405 |
-
}
|
| 406 |
-
}
|
| 407 |
-
},
|
| 408 |
"contract_start": {
|
| 409 |
"concept_id": "contract_start",
|
| 410 |
"concept_name": "Contract Start Date",
|
|
@@ -471,73 +337,6 @@
|
|
| 471 |
"join_requirements": []
|
| 472 |
}
|
| 473 |
}
|
| 474 |
-
},
|
| 475 |
-
"contract_name": {
|
| 476 |
-
"concept_id": "contract_name",
|
| 477 |
-
"concept_name": "Contract Name",
|
| 478 |
-
"description": "The name or title of the contract",
|
| 479 |
-
"aliases": [
|
| 480 |
-
"name",
|
| 481 |
-
"title",
|
| 482 |
-
"contract title",
|
| 483 |
-
"agreement name"
|
| 484 |
-
],
|
| 485 |
-
"customer_mappings": {
|
| 486 |
-
"customer_a": {
|
| 487 |
-
"customer_id": "customer_a",
|
| 488 |
-
"table_name": "contracts",
|
| 489 |
-
"column_name": "contract_name",
|
| 490 |
-
"data_type": "TEXT",
|
| 491 |
-
"semantic_type": "text",
|
| 492 |
-
"transformation": null,
|
| 493 |
-
"join_requirements": []
|
| 494 |
-
},
|
| 495 |
-
"customer_b": {
|
| 496 |
-
"customer_id": "customer_b",
|
| 497 |
-
"table_name": "contract_headers",
|
| 498 |
-
"column_name": "name",
|
| 499 |
-
"data_type": "TEXT",
|
| 500 |
-
"semantic_type": "text",
|
| 501 |
-
"transformation": null,
|
| 502 |
-
"join_requirements": []
|
| 503 |
-
},
|
| 504 |
-
"customer_c": {
|
| 505 |
-
"customer_id": "customer_c",
|
| 506 |
-
"table_name": "contracts",
|
| 507 |
-
"column_name": "name",
|
| 508 |
-
"data_type": "TEXT",
|
| 509 |
-
"semantic_type": "text",
|
| 510 |
-
"transformation": null,
|
| 511 |
-
"join_requirements": []
|
| 512 |
-
},
|
| 513 |
-
"customer_d": {
|
| 514 |
-
"customer_id": "customer_d",
|
| 515 |
-
"table_name": "contracts",
|
| 516 |
-
"column_name": "name",
|
| 517 |
-
"data_type": "TEXT",
|
| 518 |
-
"semantic_type": "text",
|
| 519 |
-
"transformation": null,
|
| 520 |
-
"join_requirements": []
|
| 521 |
-
},
|
| 522 |
-
"customer_e": {
|
| 523 |
-
"customer_id": "customer_e",
|
| 524 |
-
"table_name": "contracts",
|
| 525 |
-
"column_name": "contract_title",
|
| 526 |
-
"data_type": "TEXT",
|
| 527 |
-
"semantic_type": "text",
|
| 528 |
-
"transformation": null,
|
| 529 |
-
"join_requirements": []
|
| 530 |
-
},
|
| 531 |
-
"customer_f": {
|
| 532 |
-
"customer_id": "customer_f",
|
| 533 |
-
"table_name": "contracts",
|
| 534 |
-
"column_name": "agreement_name",
|
| 535 |
-
"data_type": "TEXT",
|
| 536 |
-
"semantic_type": "text",
|
| 537 |
-
"transformation": null,
|
| 538 |
-
"join_requirements": []
|
| 539 |
-
}
|
| 540 |
-
}
|
| 541 |
}
|
| 542 |
},
|
| 543 |
"transformations": {
|
|
|
|
| 271 |
}
|
| 272 |
}
|
| 273 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
"contract_start": {
|
| 275 |
"concept_id": "contract_start",
|
| 276 |
"concept_name": "Contract Start Date",
|
|
|
|
| 337 |
"join_requirements": []
|
| 338 |
}
|
| 339 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
}
|
| 341 |
},
|
| 342 |
"transformations": {
|
|
@@ -43,6 +43,18 @@ class QueryUnderstandingAgent:
|
|
| 43 |
Available Semantic Concepts:
|
| 44 |
{concept_list}
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
Your task is to:
|
| 47 |
1. Identify the user's query intent (find_contracts, count_contracts, aggregate_values)
|
| 48 |
2. Determine which semantic concepts are relevant
|
|
@@ -88,7 +100,8 @@ Return ONLY valid JSON matching this schema:
|
|
| 88 |
"function": "count" | "sum" | "average",
|
| 89 |
"concept": "concept_name"
|
| 90 |
}}], // optional, for aggregate_values intent
|
| 91 |
-
"limit": 10 // optional, for find_contracts intent
|
|
|
|
| 92 |
}}
|
| 93 |
|
| 94 |
IMPORTANT:
|
|
@@ -110,7 +123,8 @@ Result:
|
|
| 110 |
"intent": "find_contracts",
|
| 111 |
"filters": [],
|
| 112 |
"projections": [], // Empty projections means return ALL available fields
|
| 113 |
-
"limit": 100
|
|
|
|
| 114 |
}}
|
| 115 |
|
| 116 |
Query: "Show me all active contracts"
|
|
@@ -125,7 +139,18 @@ Result:
|
|
| 125 |
}}
|
| 126 |
],
|
| 127 |
"projections": [], // Empty projections means return ALL available fields
|
| 128 |
-
"limit": 100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
}}
|
| 130 |
|
| 131 |
Query: "How many contracts expire in the next 90 days?"
|
|
@@ -156,26 +181,21 @@ Result:
|
|
| 156 |
"value": 2000000
|
| 157 |
}}
|
| 158 |
],
|
| 159 |
-
"projections": ["contract_identifier", "contract_value", "
|
| 160 |
}}
|
| 161 |
|
| 162 |
-
Query: "List
|
| 163 |
Result:
|
| 164 |
{{
|
| 165 |
"intent": "find_contracts",
|
| 166 |
"filters": [
|
| 167 |
-
{{
|
| 168 |
-
"concept": "industry_sector",
|
| 169 |
-
"operator": "equals",
|
| 170 |
-
"value": "Technology"
|
| 171 |
-
}},
|
| 172 |
{{
|
| 173 |
"concept": "contract_expiration",
|
| 174 |
"operator": "between",
|
| 175 |
"value": ["2026-01-01", "2026-12-31"]
|
| 176 |
}}
|
| 177 |
],
|
| 178 |
-
"projections": ["contract_identifier", "
|
| 179 |
"limit": 50
|
| 180 |
}}
|
| 181 |
|
|
|
|
| 43 |
Available Semantic Concepts:
|
| 44 |
{concept_list}
|
| 45 |
|
| 46 |
+
CRITICAL RULE - Database Selection:
|
| 47 |
+
When users mention "customer" or database references, extract them into target_customers field:
|
| 48 |
+
- "customer_a", "customer_b", "customer_c", "customer_d", "customer_e", "customer_f"
|
| 49 |
+
- "customer A", "customer B", "database A", "from customer_a", "customer a database"
|
| 50 |
+
- Any reference to which specific database(s) to query
|
| 51 |
+
|
| 52 |
+
DO NOT create semantic filters for these - use the target_customers field instead.
|
| 53 |
+
Examples:
|
| 54 |
+
- "Show contracts from customer_a" → target_customers: ["customer_a"]
|
| 55 |
+
- "Query customer A and B" → target_customers: ["customer_a", "customer_b"]
|
| 56 |
+
- "Show all contracts" → target_customers: null (means query all)
|
| 57 |
+
|
| 58 |
Your task is to:
|
| 59 |
1. Identify the user's query intent (find_contracts, count_contracts, aggregate_values)
|
| 60 |
2. Determine which semantic concepts are relevant
|
|
|
|
| 100 |
"function": "count" | "sum" | "average",
|
| 101 |
"concept": "concept_name"
|
| 102 |
}}], // optional, for aggregate_values intent
|
| 103 |
+
"limit": 10, // optional, for find_contracts intent
|
| 104 |
+
"target_customers": ["customer_a", "customer_b"] // optional, null or omit for all customers
|
| 105 |
}}
|
| 106 |
|
| 107 |
IMPORTANT:
|
|
|
|
| 123 |
"intent": "find_contracts",
|
| 124 |
"filters": [],
|
| 125 |
"projections": [], // Empty projections means return ALL available fields
|
| 126 |
+
"limit": 100,
|
| 127 |
+
"target_customers": null // null means query all customers
|
| 128 |
}}
|
| 129 |
|
| 130 |
Query: "Show me all active contracts"
|
|
|
|
| 139 |
}}
|
| 140 |
],
|
| 141 |
"projections": [], // Empty projections means return ALL available fields
|
| 142 |
+
"limit": 100,
|
| 143 |
+
"target_customers": null
|
| 144 |
+
}}
|
| 145 |
+
|
| 146 |
+
Query: "Show contracts from customer_a" or "Query customer A database"
|
| 147 |
+
Result:
|
| 148 |
+
{{
|
| 149 |
+
"intent": "find_contracts",
|
| 150 |
+
"filters": [],
|
| 151 |
+
"projections": [],
|
| 152 |
+
"limit": 100,
|
| 153 |
+
"target_customers": ["customer_a"] // Extract database reference
|
| 154 |
}}
|
| 155 |
|
| 156 |
Query: "How many contracts expire in the next 90 days?"
|
|
|
|
| 181 |
"value": 2000000
|
| 182 |
}}
|
| 183 |
],
|
| 184 |
+
"projections": ["contract_identifier", "contract_value", "contract_status"]
|
| 185 |
}}
|
| 186 |
|
| 187 |
+
Query: "List contracts expiring in 2026"
|
| 188 |
Result:
|
| 189 |
{{
|
| 190 |
"intent": "find_contracts",
|
| 191 |
"filters": [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
{{
|
| 193 |
"concept": "contract_expiration",
|
| 194 |
"operator": "between",
|
| 195 |
"value": ["2026-01-01", "2026-12-31"]
|
| 196 |
}}
|
| 197 |
],
|
| 198 |
+
"projections": ["contract_identifier", "contract_expiration", "contract_value"],
|
| 199 |
"limit": 50
|
| 200 |
}}
|
| 201 |
|
|
@@ -108,22 +108,18 @@ class MockDataGenerator:
|
|
| 108 |
conn = sqlite3.connect(db_path)
|
| 109 |
cursor = conn.cursor()
|
| 110 |
|
| 111 |
-
# Create table
|
| 112 |
cursor.execute("""
|
| 113 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 114 |
contract_id INTEGER PRIMARY KEY,
|
| 115 |
-
contract_name TEXT NOT NULL,
|
| 116 |
-
customer_name TEXT NOT NULL,
|
| 117 |
contract_value INTEGER NOT NULL,
|
| 118 |
status TEXT NOT NULL,
|
| 119 |
expiry_date TEXT NOT NULL,
|
| 120 |
-
start_date TEXT NOT NULL
|
| 121 |
-
industry TEXT NOT NULL
|
| 122 |
)
|
| 123 |
""")
|
| 124 |
|
| 125 |
# Generate data
|
| 126 |
-
industries = self.INDUSTRIES["customer_a"]
|
| 127 |
for i in range(1, 51):
|
| 128 |
start_date, expiry_date = self.generate_dates()
|
| 129 |
|
|
@@ -135,18 +131,14 @@ class MockDataGenerator:
|
|
| 135 |
|
| 136 |
cursor.execute("""
|
| 137 |
INSERT INTO contracts
|
| 138 |
-
(contract_id,
|
| 139 |
-
|
| 140 |
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
| 141 |
""", (
|
| 142 |
i,
|
| 143 |
-
self.generate_contract_name(i),
|
| 144 |
-
self.generate_company_name(),
|
| 145 |
self.generate_contract_value(is_annual=False),
|
| 146 |
status,
|
| 147 |
expiry_date.strftime("%Y-%m-%d"),
|
| 148 |
-
start_date.strftime("%Y-%m-%d")
|
| 149 |
-
random.choice(industries)
|
| 150 |
))
|
| 151 |
|
| 152 |
conn.commit()
|
|
@@ -159,15 +151,12 @@ class MockDataGenerator:
|
|
| 159 |
conn = sqlite3.connect(db_path)
|
| 160 |
cursor = conn.cursor()
|
| 161 |
|
| 162 |
-
# Create tables
|
| 163 |
cursor.execute("""
|
| 164 |
CREATE TABLE IF NOT EXISTS contract_headers (
|
| 165 |
id INTEGER PRIMARY KEY,
|
| 166 |
-
contract_name TEXT NOT NULL,
|
| 167 |
-
client_name TEXT NOT NULL,
|
| 168 |
contract_value INTEGER NOT NULL,
|
| 169 |
-
start_date TEXT NOT NULL
|
| 170 |
-
sector TEXT NOT NULL
|
| 171 |
)
|
| 172 |
""")
|
| 173 |
|
|
@@ -191,22 +180,18 @@ class MockDataGenerator:
|
|
| 191 |
""")
|
| 192 |
|
| 193 |
# Generate data
|
| 194 |
-
industries = self.INDUSTRIES["customer_b"]
|
| 195 |
for i in range(1, 51):
|
| 196 |
start_date, expiry_date = self.generate_dates()
|
| 197 |
|
| 198 |
# Insert header
|
| 199 |
cursor.execute("""
|
| 200 |
INSERT INTO contract_headers
|
| 201 |
-
(id,
|
| 202 |
-
VALUES (?, ?,
|
| 203 |
""", (
|
| 204 |
i,
|
| 205 |
-
self.generate_contract_name(i),
|
| 206 |
-
self.generate_company_name(),
|
| 207 |
self.generate_contract_value(is_annual=False),
|
| 208 |
-
start_date.strftime("%Y-%m-%d")
|
| 209 |
-
random.choice(industries)
|
| 210 |
))
|
| 211 |
|
| 212 |
# Insert status history (1-3 status changes)
|
|
@@ -244,22 +229,18 @@ class MockDataGenerator:
|
|
| 244 |
conn = sqlite3.connect(db_path)
|
| 245 |
cursor = conn.cursor()
|
| 246 |
|
| 247 |
-
# Create table with different column names
|
| 248 |
cursor.execute("""
|
| 249 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 250 |
id INTEGER PRIMARY KEY,
|
| 251 |
-
name TEXT NOT NULL,
|
| 252 |
-
account TEXT NOT NULL,
|
| 253 |
total_value INTEGER NOT NULL,
|
| 254 |
current_status TEXT NOT NULL,
|
| 255 |
expiration_date TEXT NOT NULL,
|
| 256 |
-
inception_date TEXT NOT NULL
|
| 257 |
-
business_sector TEXT NOT NULL
|
| 258 |
)
|
| 259 |
""")
|
| 260 |
|
| 261 |
# Generate data
|
| 262 |
-
industries = self.INDUSTRIES["customer_c"]
|
| 263 |
for i in range(1, 51):
|
| 264 |
start_date, expiry_date = self.generate_dates()
|
| 265 |
|
|
@@ -270,18 +251,14 @@ class MockDataGenerator:
|
|
| 270 |
|
| 271 |
cursor.execute("""
|
| 272 |
INSERT INTO contracts
|
| 273 |
-
(id,
|
| 274 |
-
|
| 275 |
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
| 276 |
""", (
|
| 277 |
i,
|
| 278 |
-
self.generate_contract_name(i),
|
| 279 |
-
self.generate_company_name(),
|
| 280 |
self.generate_contract_value(is_annual=False),
|
| 281 |
status,
|
| 282 |
expiry_date.strftime("%Y-%m-%d"),
|
| 283 |
-
start_date.strftime("%Y-%m-%d")
|
| 284 |
-
random.choice(industries)
|
| 285 |
))
|
| 286 |
|
| 287 |
conn.commit()
|
|
@@ -294,22 +271,18 @@ class MockDataGenerator:
|
|
| 294 |
conn = sqlite3.connect(db_path)
|
| 295 |
cursor = conn.cursor()
|
| 296 |
|
| 297 |
-
# Create table with days_remaining instead of date
|
| 298 |
cursor.execute("""
|
| 299 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 300 |
contract_id INTEGER PRIMARY KEY,
|
| 301 |
-
contract_title TEXT NOT NULL,
|
| 302 |
-
customer_org TEXT NOT NULL,
|
| 303 |
contract_value INTEGER NOT NULL,
|
| 304 |
status TEXT NOT NULL,
|
| 305 |
days_remaining INTEGER NOT NULL,
|
| 306 |
-
start_date TEXT NOT NULL
|
| 307 |
-
industry TEXT NOT NULL
|
| 308 |
)
|
| 309 |
""")
|
| 310 |
|
| 311 |
# Generate data
|
| 312 |
-
industries = self.INDUSTRIES["customer_d"]
|
| 313 |
for i in range(1, 51):
|
| 314 |
start_date, expiry_date = self.generate_dates()
|
| 315 |
|
|
@@ -323,18 +296,14 @@ class MockDataGenerator:
|
|
| 323 |
|
| 324 |
cursor.execute("""
|
| 325 |
INSERT INTO contracts
|
| 326 |
-
(contract_id,
|
| 327 |
-
|
| 328 |
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
| 329 |
""", (
|
| 330 |
i,
|
| 331 |
-
self.generate_contract_name(i),
|
| 332 |
-
self.generate_company_name(),
|
| 333 |
self.generate_contract_value(is_annual=False),
|
| 334 |
status,
|
| 335 |
days_remaining,
|
| 336 |
-
start_date.strftime("%Y-%m-%d")
|
| 337 |
-
random.choice(industries)
|
| 338 |
))
|
| 339 |
|
| 340 |
conn.commit()
|
|
@@ -347,23 +316,19 @@ class MockDataGenerator:
|
|
| 347 |
conn = sqlite3.connect(db_path)
|
| 348 |
cursor = conn.cursor()
|
| 349 |
|
| 350 |
-
# Create table
|
| 351 |
cursor.execute("""
|
| 352 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 353 |
contract_id INTEGER PRIMARY KEY,
|
| 354 |
-
contract_name TEXT NOT NULL,
|
| 355 |
-
customer_name TEXT NOT NULL,
|
| 356 |
contract_value INTEGER NOT NULL,
|
| 357 |
term_years REAL NOT NULL,
|
| 358 |
status TEXT NOT NULL,
|
| 359 |
expiry_date TEXT NOT NULL,
|
| 360 |
-
start_date TEXT NOT NULL
|
| 361 |
-
industry TEXT NOT NULL
|
| 362 |
)
|
| 363 |
""")
|
| 364 |
|
| 365 |
# Generate data
|
| 366 |
-
industries = self.INDUSTRIES["customer_e"]
|
| 367 |
for i in range(1, 51):
|
| 368 |
start_date, expiry_date = self.generate_dates()
|
| 369 |
|
|
@@ -378,19 +343,15 @@ class MockDataGenerator:
|
|
| 378 |
|
| 379 |
cursor.execute("""
|
| 380 |
INSERT INTO contracts
|
| 381 |
-
(contract_id,
|
| 382 |
-
|
| 383 |
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 384 |
""", (
|
| 385 |
i,
|
| 386 |
-
self.generate_contract_name(i),
|
| 387 |
-
self.generate_company_name(),
|
| 388 |
self.generate_contract_value(is_annual=False),
|
| 389 |
term_years,
|
| 390 |
status,
|
| 391 |
expiry_date.strftime("%Y-%m-%d"),
|
| 392 |
-
start_date.strftime("%Y-%m-%d")
|
| 393 |
-
random.choice(industries)
|
| 394 |
))
|
| 395 |
|
| 396 |
conn.commit()
|
|
@@ -403,23 +364,19 @@ class MockDataGenerator:
|
|
| 403 |
conn = sqlite3.connect(db_path)
|
| 404 |
cursor = conn.cursor()
|
| 405 |
|
| 406 |
-
# Create table
|
| 407 |
cursor.execute("""
|
| 408 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 409 |
contract_id INTEGER PRIMARY KEY,
|
| 410 |
-
name TEXT NOT NULL,
|
| 411 |
-
account TEXT NOT NULL,
|
| 412 |
contract_value INTEGER NOT NULL,
|
| 413 |
term_years REAL NOT NULL,
|
| 414 |
status TEXT NOT NULL,
|
| 415 |
expiration_date TEXT NOT NULL,
|
| 416 |
-
start_date TEXT NOT NULL
|
| 417 |
-
sector TEXT NOT NULL
|
| 418 |
)
|
| 419 |
""")
|
| 420 |
|
| 421 |
# Generate data
|
| 422 |
-
industries = self.INDUSTRIES["customer_f"]
|
| 423 |
for i in range(1, 51):
|
| 424 |
start_date, expiry_date = self.generate_dates()
|
| 425 |
|
|
@@ -434,19 +391,15 @@ class MockDataGenerator:
|
|
| 434 |
|
| 435 |
cursor.execute("""
|
| 436 |
INSERT INTO contracts
|
| 437 |
-
(contract_id,
|
| 438 |
-
|
| 439 |
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
| 440 |
""", (
|
| 441 |
i,
|
| 442 |
-
self.generate_contract_name(i),
|
| 443 |
-
self.generate_company_name(),
|
| 444 |
self.generate_contract_value(is_annual=True), # ANNUAL value
|
| 445 |
term_years,
|
| 446 |
status,
|
| 447 |
expiry_date.strftime("%Y-%m-%d"),
|
| 448 |
-
start_date.strftime("%Y-%m-%d")
|
| 449 |
-
random.choice(industries)
|
| 450 |
))
|
| 451 |
|
| 452 |
conn.commit()
|
|
|
|
| 108 |
conn = sqlite3.connect(db_path)
|
| 109 |
cursor = conn.cursor()
|
| 110 |
|
| 111 |
+
# Create table (removed customer_name and industry - not in knowledge graph)
|
| 112 |
cursor.execute("""
|
| 113 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 114 |
contract_id INTEGER PRIMARY KEY,
|
|
|
|
|
|
|
| 115 |
contract_value INTEGER NOT NULL,
|
| 116 |
status TEXT NOT NULL,
|
| 117 |
expiry_date TEXT NOT NULL,
|
| 118 |
+
start_date TEXT NOT NULL
|
|
|
|
| 119 |
)
|
| 120 |
""")
|
| 121 |
|
| 122 |
# Generate data
|
|
|
|
| 123 |
for i in range(1, 51):
|
| 124 |
start_date, expiry_date = self.generate_dates()
|
| 125 |
|
|
|
|
| 131 |
|
| 132 |
cursor.execute("""
|
| 133 |
INSERT INTO contracts
|
| 134 |
+
(contract_id, contract_value, status, expiry_date, start_date)
|
| 135 |
+
VALUES (?, ?, ?, ?, ?)
|
|
|
|
| 136 |
""", (
|
| 137 |
i,
|
|
|
|
|
|
|
| 138 |
self.generate_contract_value(is_annual=False),
|
| 139 |
status,
|
| 140 |
expiry_date.strftime("%Y-%m-%d"),
|
| 141 |
+
start_date.strftime("%Y-%m-%d")
|
|
|
|
| 142 |
))
|
| 143 |
|
| 144 |
conn.commit()
|
|
|
|
| 151 |
conn = sqlite3.connect(db_path)
|
| 152 |
cursor = conn.cursor()
|
| 153 |
|
| 154 |
+
# Create tables (removed client_name and sector)
|
| 155 |
cursor.execute("""
|
| 156 |
CREATE TABLE IF NOT EXISTS contract_headers (
|
| 157 |
id INTEGER PRIMARY KEY,
|
|
|
|
|
|
|
| 158 |
contract_value INTEGER NOT NULL,
|
| 159 |
+
start_date TEXT NOT NULL
|
|
|
|
| 160 |
)
|
| 161 |
""")
|
| 162 |
|
|
|
|
| 180 |
""")
|
| 181 |
|
| 182 |
# Generate data
|
|
|
|
| 183 |
for i in range(1, 51):
|
| 184 |
start_date, expiry_date = self.generate_dates()
|
| 185 |
|
| 186 |
# Insert header
|
| 187 |
cursor.execute("""
|
| 188 |
INSERT INTO contract_headers
|
| 189 |
+
(id, contract_value, start_date)
|
| 190 |
+
VALUES (?, ?, ?)
|
| 191 |
""", (
|
| 192 |
i,
|
|
|
|
|
|
|
| 193 |
self.generate_contract_value(is_annual=False),
|
| 194 |
+
start_date.strftime("%Y-%m-%d")
|
|
|
|
| 195 |
))
|
| 196 |
|
| 197 |
# Insert status history (1-3 status changes)
|
|
|
|
| 229 |
conn = sqlite3.connect(db_path)
|
| 230 |
cursor = conn.cursor()
|
| 231 |
|
| 232 |
+
# Create table with different column names (removed account and business_sector)
|
| 233 |
cursor.execute("""
|
| 234 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 235 |
id INTEGER PRIMARY KEY,
|
|
|
|
|
|
|
| 236 |
total_value INTEGER NOT NULL,
|
| 237 |
current_status TEXT NOT NULL,
|
| 238 |
expiration_date TEXT NOT NULL,
|
| 239 |
+
inception_date TEXT NOT NULL
|
|
|
|
| 240 |
)
|
| 241 |
""")
|
| 242 |
|
| 243 |
# Generate data
|
|
|
|
| 244 |
for i in range(1, 51):
|
| 245 |
start_date, expiry_date = self.generate_dates()
|
| 246 |
|
|
|
|
| 251 |
|
| 252 |
cursor.execute("""
|
| 253 |
INSERT INTO contracts
|
| 254 |
+
(id, total_value, current_status, expiration_date, inception_date)
|
| 255 |
+
VALUES (?, ?, ?, ?, ?)
|
|
|
|
| 256 |
""", (
|
| 257 |
i,
|
|
|
|
|
|
|
| 258 |
self.generate_contract_value(is_annual=False),
|
| 259 |
status,
|
| 260 |
expiry_date.strftime("%Y-%m-%d"),
|
| 261 |
+
start_date.strftime("%Y-%m-%d")
|
|
|
|
| 262 |
))
|
| 263 |
|
| 264 |
conn.commit()
|
|
|
|
| 271 |
conn = sqlite3.connect(db_path)
|
| 272 |
cursor = conn.cursor()
|
| 273 |
|
| 274 |
+
# Create table with days_remaining instead of date (removed customer_org and industry)
|
| 275 |
cursor.execute("""
|
| 276 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 277 |
contract_id INTEGER PRIMARY KEY,
|
|
|
|
|
|
|
| 278 |
contract_value INTEGER NOT NULL,
|
| 279 |
status TEXT NOT NULL,
|
| 280 |
days_remaining INTEGER NOT NULL,
|
| 281 |
+
start_date TEXT NOT NULL
|
|
|
|
| 282 |
)
|
| 283 |
""")
|
| 284 |
|
| 285 |
# Generate data
|
|
|
|
| 286 |
for i in range(1, 51):
|
| 287 |
start_date, expiry_date = self.generate_dates()
|
| 288 |
|
|
|
|
| 296 |
|
| 297 |
cursor.execute("""
|
| 298 |
INSERT INTO contracts
|
| 299 |
+
(contract_id, contract_value, status, days_remaining, start_date)
|
| 300 |
+
VALUES (?, ?, ?, ?, ?)
|
|
|
|
| 301 |
""", (
|
| 302 |
i,
|
|
|
|
|
|
|
| 303 |
self.generate_contract_value(is_annual=False),
|
| 304 |
status,
|
| 305 |
days_remaining,
|
| 306 |
+
start_date.strftime("%Y-%m-%d")
|
|
|
|
| 307 |
))
|
| 308 |
|
| 309 |
conn.commit()
|
|
|
|
| 316 |
conn = sqlite3.connect(db_path)
|
| 317 |
cursor = conn.cursor()
|
| 318 |
|
| 319 |
+
# Create table (removed customer_name and industry)
|
| 320 |
cursor.execute("""
|
| 321 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 322 |
contract_id INTEGER PRIMARY KEY,
|
|
|
|
|
|
|
| 323 |
contract_value INTEGER NOT NULL,
|
| 324 |
term_years REAL NOT NULL,
|
| 325 |
status TEXT NOT NULL,
|
| 326 |
expiry_date TEXT NOT NULL,
|
| 327 |
+
start_date TEXT NOT NULL
|
|
|
|
| 328 |
)
|
| 329 |
""")
|
| 330 |
|
| 331 |
# Generate data
|
|
|
|
| 332 |
for i in range(1, 51):
|
| 333 |
start_date, expiry_date = self.generate_dates()
|
| 334 |
|
|
|
|
| 343 |
|
| 344 |
cursor.execute("""
|
| 345 |
INSERT INTO contracts
|
| 346 |
+
(contract_id, contract_value, term_years, status, expiry_date, start_date)
|
| 347 |
+
VALUES (?, ?, ?, ?, ?, ?)
|
|
|
|
| 348 |
""", (
|
| 349 |
i,
|
|
|
|
|
|
|
| 350 |
self.generate_contract_value(is_annual=False),
|
| 351 |
term_years,
|
| 352 |
status,
|
| 353 |
expiry_date.strftime("%Y-%m-%d"),
|
| 354 |
+
start_date.strftime("%Y-%m-%d")
|
|
|
|
| 355 |
))
|
| 356 |
|
| 357 |
conn.commit()
|
|
|
|
| 364 |
conn = sqlite3.connect(db_path)
|
| 365 |
cursor = conn.cursor()
|
| 366 |
|
| 367 |
+
# Create table (removed account and sector)
|
| 368 |
cursor.execute("""
|
| 369 |
CREATE TABLE IF NOT EXISTS contracts (
|
| 370 |
contract_id INTEGER PRIMARY KEY,
|
|
|
|
|
|
|
| 371 |
contract_value INTEGER NOT NULL,
|
| 372 |
term_years REAL NOT NULL,
|
| 373 |
status TEXT NOT NULL,
|
| 374 |
expiration_date TEXT NOT NULL,
|
| 375 |
+
start_date TEXT NOT NULL
|
|
|
|
| 376 |
)
|
| 377 |
""")
|
| 378 |
|
| 379 |
# Generate data
|
|
|
|
| 380 |
for i in range(1, 51):
|
| 381 |
start_date, expiry_date = self.generate_dates()
|
| 382 |
|
|
|
|
| 391 |
|
| 392 |
cursor.execute("""
|
| 393 |
INSERT INTO contracts
|
| 394 |
+
(contract_id, contract_value, term_years, status, expiration_date, start_date)
|
| 395 |
+
VALUES (?, ?, ?, ?, ?, ?)
|
|
|
|
| 396 |
""", (
|
| 397 |
i,
|
|
|
|
|
|
|
| 398 |
self.generate_contract_value(is_annual=True), # ANNUAL value
|
| 399 |
term_years,
|
| 400 |
status,
|
| 401 |
expiry_date.strftime("%Y-%m-%d"),
|
| 402 |
+
start_date.strftime("%Y-%m-%d")
|
|
|
|
| 403 |
))
|
| 404 |
|
| 405 |
conn.commit()
|
|
@@ -90,8 +90,8 @@ class SchemaTable(BaseModel):
|
|
| 90 |
|
| 91 |
class CustomerSchema(BaseModel):
|
| 92 |
"""Represents the complete schema for a customer database."""
|
| 93 |
-
customer_id: str = Field(..., description="Unique customer identifier")
|
| 94 |
-
customer_name: str = Field(
|
| 95 |
tables: List[SchemaTable] = Field(..., description="Tables in this schema")
|
| 96 |
semantic_notes: Dict[str, str] = Field(
|
| 97 |
default_factory=dict,
|
|
@@ -170,6 +170,10 @@ class SemanticQueryPlan(BaseModel):
|
|
| 170 |
description="Ordering (concept, direction) pairs"
|
| 171 |
)
|
| 172 |
limit: Optional[int] = Field(None, description="Maximum number of results")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
|
| 174 |
model_config = {"use_enum_values": True}
|
| 175 |
|
|
|
|
| 90 |
|
| 91 |
class CustomerSchema(BaseModel):
|
| 92 |
"""Represents the complete schema for a customer database."""
|
| 93 |
+
customer_id: str = Field(..., description="Unique customer identifier (e.g., customer_a)")
|
| 94 |
+
customer_name: Optional[str] = Field(None, description="Optional customer display name")
|
| 95 |
tables: List[SchemaTable] = Field(..., description="Tables in this schema")
|
| 96 |
semantic_notes: Dict[str, str] = Field(
|
| 97 |
default_factory=dict,
|
|
|
|
| 170 |
description="Ordering (concept, direction) pairs"
|
| 171 |
)
|
| 172 |
limit: Optional[int] = Field(None, description="Maximum number of results")
|
| 173 |
+
target_customers: Optional[List[str]] = Field(
|
| 174 |
+
None,
|
| 175 |
+
description="Specific customer databases to query (e.g., ['customer_a', 'customer_b']). None means all customers."
|
| 176 |
+
)
|
| 177 |
|
| 178 |
model_config = {"use_enum_values": True}
|
| 179 |
|
|
@@ -122,14 +122,21 @@ class ChatOrchestrator:
|
|
| 122 |
if debug:
|
| 123 |
logger.info(f"Semantic plan: {semantic_plan}")
|
| 124 |
|
| 125 |
-
# Step 3:
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
result = self.result_harmonizer.execute_across_customers(
|
| 128 |
semantic_plan,
|
| 129 |
-
customer_ids=
|
| 130 |
)
|
| 131 |
|
| 132 |
-
# Step
|
| 133 |
total_time_ms = (time.time() - start_time) * 1000
|
| 134 |
|
| 135 |
logger.info(
|
|
@@ -138,7 +145,7 @@ class ChatOrchestrator:
|
|
| 138 |
f"{total_time_ms:.2f}ms"
|
| 139 |
)
|
| 140 |
|
| 141 |
-
# Step
|
| 142 |
self._add_to_history(
|
| 143 |
query_text=query_text,
|
| 144 |
semantic_plan=semantic_plan,
|
|
@@ -147,7 +154,7 @@ class ChatOrchestrator:
|
|
| 147 |
error=None
|
| 148 |
)
|
| 149 |
|
| 150 |
-
# Step
|
| 151 |
response = {
|
| 152 |
"success": True,
|
| 153 |
"query_text": query_text,
|
|
|
|
| 122 |
if debug:
|
| 123 |
logger.info(f"Semantic plan: {semantic_plan}")
|
| 124 |
|
| 125 |
+
# Step 3: Determine which customers to query
|
| 126 |
+
# Priority: explicit customer_ids parameter > target_customers from query > all customers
|
| 127 |
+
target_customers = customer_ids
|
| 128 |
+
if target_customers is None and semantic_plan.target_customers:
|
| 129 |
+
target_customers = semantic_plan.target_customers
|
| 130 |
+
logger.info(f"Extracted target customers from query: {target_customers}")
|
| 131 |
+
|
| 132 |
+
# Step 4: Execute query across customers
|
| 133 |
+
logger.info(f"Executing query across {len(target_customers) if target_customers else 'all'} customers...")
|
| 134 |
result = self.result_harmonizer.execute_across_customers(
|
| 135 |
semantic_plan,
|
| 136 |
+
customer_ids=target_customers
|
| 137 |
)
|
| 138 |
|
| 139 |
+
# Step 5: Calculate total execution time
|
| 140 |
total_time_ms = (time.time() - start_time) * 1000
|
| 141 |
|
| 142 |
logger.info(
|
|
|
|
| 145 |
f"{total_time_ms:.2f}ms"
|
| 146 |
)
|
| 147 |
|
| 148 |
+
# Step 6: Add to history
|
| 149 |
self._add_to_history(
|
| 150 |
query_text=query_text,
|
| 151 |
semantic_plan=semantic_plan,
|
|
|
|
| 154 |
error=None
|
| 155 |
)
|
| 156 |
|
| 157 |
+
# Step 7: Build response
|
| 158 |
response = {
|
| 159 |
"success": True,
|
| 160 |
"query_text": query_text,
|
|
@@ -687,6 +687,7 @@ class ResultHarmonizer:
|
|
| 687 |
"""
|
| 688 |
harmonized = {}
|
| 689 |
|
|
|
|
| 690 |
for customer_field, concept_id in field_mappings.items():
|
| 691 |
if customer_field in row:
|
| 692 |
value = row[customer_field]
|
|
@@ -702,4 +703,12 @@ class ResultHarmonizer:
|
|
| 702 |
# Field not present in row
|
| 703 |
harmonized[concept_id] = None
|
| 704 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
return harmonized
|
|
|
|
| 687 |
"""
|
| 688 |
harmonized = {}
|
| 689 |
|
| 690 |
+
# First, map all fields that have concept mappings
|
| 691 |
for customer_field, concept_id in field_mappings.items():
|
| 692 |
if customer_field in row:
|
| 693 |
value = row[customer_field]
|
|
|
|
| 703 |
# Field not present in row
|
| 704 |
harmonized[concept_id] = None
|
| 705 |
|
| 706 |
+
# Second, include any unmapped columns (like aggregation results)
|
| 707 |
+
# These are columns like count_contract_identifier, sum_contract_value, etc.
|
| 708 |
+
for field_name, value in row.items():
|
| 709 |
+
if field_name not in field_mappings:
|
| 710 |
+
# This is an unmapped column (likely an aggregation)
|
| 711 |
+
# Include it as-is with its original name
|
| 712 |
+
harmonized[field_name] = value
|
| 713 |
+
|
| 714 |
return harmonized
|
|
@@ -98,11 +98,11 @@ class TestQueryUnderstandingAgent:
|
|
| 98 |
|
| 99 |
def test_multi_filter_query(self, query_agent):
|
| 100 |
"""Test parsing a query with multiple filters."""
|
| 101 |
-
query = "List
|
| 102 |
plan = query_agent.understand_query(query)
|
| 103 |
|
| 104 |
assert plan.intent == QueryIntent.FIND_CONTRACTS
|
| 105 |
-
assert len(plan.filters) >= 2
|
| 106 |
|
| 107 |
def test_explain_query_plan(self, query_agent):
|
| 108 |
"""Test generating human-readable explanations."""
|
|
@@ -122,7 +122,6 @@ class TestSchemaAnalyzerAgent:
|
|
| 122 |
# Create a simple test schema
|
| 123 |
schema = CustomerSchema(
|
| 124 |
customer_id="test_customer",
|
| 125 |
-
customer_name="Test Customer",
|
| 126 |
tables=[
|
| 127 |
SchemaTable(
|
| 128 |
name="contracts",
|
|
@@ -173,7 +172,6 @@ class TestSchemaAnalyzerAgent:
|
|
| 173 |
|
| 174 |
schema = CustomerSchema(
|
| 175 |
customer_id="test_customer",
|
| 176 |
-
customer_name="Test Customer",
|
| 177 |
tables=[
|
| 178 |
SchemaTable(
|
| 179 |
name="contracts",
|
|
@@ -301,7 +299,6 @@ class TestEndToEndAgent:
|
|
| 301 |
# Create a test schema
|
| 302 |
schema = CustomerSchema(
|
| 303 |
customer_id="test_new_customer",
|
| 304 |
-
customer_name="Test New Customer",
|
| 305 |
tables=[
|
| 306 |
SchemaTable(
|
| 307 |
name="agreements",
|
|
|
|
| 98 |
|
| 99 |
def test_multi_filter_query(self, query_agent):
|
| 100 |
"""Test parsing a query with multiple filters."""
|
| 101 |
+
query = "List active contracts worth more than 1 million expiring in 2026"
|
| 102 |
plan = query_agent.understand_query(query)
|
| 103 |
|
| 104 |
assert plan.intent == QueryIntent.FIND_CONTRACTS
|
| 105 |
+
assert len(plan.filters) >= 2 # Should have status + value + date filters
|
| 106 |
|
| 107 |
def test_explain_query_plan(self, query_agent):
|
| 108 |
"""Test generating human-readable explanations."""
|
|
|
|
| 122 |
# Create a simple test schema
|
| 123 |
schema = CustomerSchema(
|
| 124 |
customer_id="test_customer",
|
|
|
|
| 125 |
tables=[
|
| 126 |
SchemaTable(
|
| 127 |
name="contracts",
|
|
|
|
| 172 |
|
| 173 |
schema = CustomerSchema(
|
| 174 |
customer_id="test_customer",
|
|
|
|
| 175 |
tables=[
|
| 176 |
SchemaTable(
|
| 177 |
name="contracts",
|
|
|
|
| 299 |
# Create a test schema
|
| 300 |
schema = CustomerSchema(
|
| 301 |
customer_id="test_new_customer",
|
|
|
|
| 302 |
tables=[
|
| 303 |
SchemaTable(
|
| 304 |
name="agreements",
|
|
@@ -269,12 +269,13 @@ class TestLoadedKnowledgeGraph:
|
|
| 269 |
kg = SchemaKnowledgeGraph()
|
| 270 |
kg.load()
|
| 271 |
|
| 272 |
-
# Check concepts exist (
|
| 273 |
-
assert len(kg.concepts) ==
|
| 274 |
assert kg.get_concept("contract_identifier") is not None
|
| 275 |
-
assert kg.get_concept("contract_name") is not None
|
| 276 |
assert kg.get_concept("contract_value") is not None
|
| 277 |
assert kg.get_concept("contract_expiration") is not None
|
|
|
|
|
|
|
| 278 |
|
| 279 |
def test_all_customers_mapped(self):
|
| 280 |
"""Test that all 6 customers have mappings."""
|
|
|
|
| 269 |
kg = SchemaKnowledgeGraph()
|
| 270 |
kg.load()
|
| 271 |
|
| 272 |
+
# Check concepts exist (5 concepts, removed customer_name and industry_sector)
|
| 273 |
+
assert len(kg.concepts) == 5
|
| 274 |
assert kg.get_concept("contract_identifier") is not None
|
|
|
|
| 275 |
assert kg.get_concept("contract_value") is not None
|
| 276 |
assert kg.get_concept("contract_expiration") is not None
|
| 277 |
+
assert kg.get_concept("contract_status") is not None
|
| 278 |
+
assert kg.get_concept("contract_start") is not None
|
| 279 |
|
| 280 |
def test_all_customers_mapped(self):
|
| 281 |
"""Test that all 6 customers have mappings."""
|
|
@@ -89,7 +89,6 @@ class TestSchemaModels:
|
|
| 89 |
|
| 90 |
schema = CustomerSchema(
|
| 91 |
customer_id="customer_a",
|
| 92 |
-
customer_name="Customer A",
|
| 93 |
tables=[table],
|
| 94 |
semantic_notes={"contract_value": "lifetime total"}
|
| 95 |
)
|
|
|
|
| 89 |
|
| 90 |
schema = CustomerSchema(
|
| 91 |
customer_id="customer_a",
|
|
|
|
| 92 |
tables=[table],
|
| 93 |
semantic_notes={"contract_value": "lifetime total"}
|
| 94 |
)
|
|
@@ -151,16 +151,16 @@ class TestQueryCompiler:
|
|
| 151 |
|
| 152 |
assert "LIMIT 10" in sql
|
| 153 |
|
| 154 |
-
def
|
| 155 |
-
"""Test filtering by
|
| 156 |
plan = SemanticQueryPlan(
|
| 157 |
intent=QueryIntent.FIND_CONTRACTS,
|
| 158 |
-
projections=["contract_identifier", "
|
| 159 |
filters=[
|
| 160 |
QueryFilter(
|
| 161 |
-
concept="
|
| 162 |
operator=QueryOperator.EQUALS,
|
| 163 |
-
value="
|
| 164 |
)
|
| 165 |
]
|
| 166 |
)
|
|
@@ -168,8 +168,8 @@ class TestQueryCompiler:
|
|
| 168 |
sql = compiler.compile_for_customer(plan, "customer_a")
|
| 169 |
|
| 170 |
assert "WHERE" in sql
|
| 171 |
-
assert "
|
| 172 |
-
assert "
|
| 173 |
|
| 174 |
|
| 175 |
class TestDatabaseExecutor:
|
|
@@ -182,7 +182,7 @@ class TestDatabaseExecutor:
|
|
| 182 |
|
| 183 |
def test_simple_query(self, executor):
|
| 184 |
"""Test executing a simple query."""
|
| 185 |
-
sql = "SELECT contract_id,
|
| 186 |
result = executor.execute_query("customer_a", sql)
|
| 187 |
|
| 188 |
assert result.success
|
|
@@ -197,7 +197,7 @@ class TestDatabaseExecutor:
|
|
| 197 |
|
| 198 |
assert result.success
|
| 199 |
assert result.row_count == 1
|
| 200 |
-
assert result.data[0]["count"] ==
|
| 201 |
|
| 202 |
def test_invalid_sql(self, executor):
|
| 203 |
"""Test handling of invalid SQL."""
|
|
@@ -223,12 +223,12 @@ class TestDatabaseExecutor:
|
|
| 223 |
def test_count_rows(self, executor):
|
| 224 |
"""Test counting rows in a table."""
|
| 225 |
count = executor.count_rows("customer_a", "contracts")
|
| 226 |
-
assert count ==
|
| 227 |
|
| 228 |
def test_customer_b_multi_table(self, executor):
|
| 229 |
"""Test querying Customer B's multi-table schema."""
|
| 230 |
sql = """
|
| 231 |
-
SELECT h.id, h.
|
| 232 |
FROM contract_headers AS h
|
| 233 |
JOIN renewal_schedule AS r ON h.id = r.contract_id
|
| 234 |
LIMIT 5
|
|
@@ -262,7 +262,7 @@ class TestIntegration:
|
|
| 262 |
# Create query plan
|
| 263 |
plan = SemanticQueryPlan(
|
| 264 |
intent=QueryIntent.FIND_CONTRACTS,
|
| 265 |
-
projections=["contract_identifier", "contract_value", "
|
| 266 |
filters=[],
|
| 267 |
limit=10
|
| 268 |
)
|
|
@@ -355,7 +355,7 @@ class TestIntegration:
|
|
| 355 |
"""Test executing the same semantic query across all customers."""
|
| 356 |
plan = SemanticQueryPlan(
|
| 357 |
intent=QueryIntent.FIND_CONTRACTS,
|
| 358 |
-
projections=["contract_identifier", "contract_value", "
|
| 359 |
filters=[],
|
| 360 |
limit=5
|
| 361 |
)
|
|
|
|
| 151 |
|
| 152 |
assert "LIMIT 10" in sql
|
| 153 |
|
| 154 |
+
def test_status_filter(self, compiler):
|
| 155 |
+
"""Test filtering by contract status."""
|
| 156 |
plan = SemanticQueryPlan(
|
| 157 |
intent=QueryIntent.FIND_CONTRACTS,
|
| 158 |
+
projections=["contract_identifier", "contract_status"],
|
| 159 |
filters=[
|
| 160 |
QueryFilter(
|
| 161 |
+
concept="contract_status",
|
| 162 |
operator=QueryOperator.EQUALS,
|
| 163 |
+
value="active"
|
| 164 |
)
|
| 165 |
]
|
| 166 |
)
|
|
|
|
| 168 |
sql = compiler.compile_for_customer(plan, "customer_a")
|
| 169 |
|
| 170 |
assert "WHERE" in sql
|
| 171 |
+
assert "status" in sql
|
| 172 |
+
assert "active" in sql
|
| 173 |
|
| 174 |
|
| 175 |
class TestDatabaseExecutor:
|
|
|
|
| 182 |
|
| 183 |
def test_simple_query(self, executor):
|
| 184 |
"""Test executing a simple query."""
|
| 185 |
+
sql = "SELECT contract_id, contract_value FROM contracts LIMIT 5"
|
| 186 |
result = executor.execute_query("customer_a", sql)
|
| 187 |
|
| 188 |
assert result.success
|
|
|
|
| 197 |
|
| 198 |
assert result.success
|
| 199 |
assert result.row_count == 1
|
| 200 |
+
assert result.data[0]["count"] == 50 # 50 contracts per database
|
| 201 |
|
| 202 |
def test_invalid_sql(self, executor):
|
| 203 |
"""Test handling of invalid SQL."""
|
|
|
|
| 223 |
def test_count_rows(self, executor):
|
| 224 |
"""Test counting rows in a table."""
|
| 225 |
count = executor.count_rows("customer_a", "contracts")
|
| 226 |
+
assert count == 50 # 50 contracts per database
|
| 227 |
|
| 228 |
def test_customer_b_multi_table(self, executor):
|
| 229 |
"""Test querying Customer B's multi-table schema."""
|
| 230 |
sql = """
|
| 231 |
+
SELECT h.id, h.contract_value, r.renewal_date
|
| 232 |
FROM contract_headers AS h
|
| 233 |
JOIN renewal_schedule AS r ON h.id = r.contract_id
|
| 234 |
LIMIT 5
|
|
|
|
| 262 |
# Create query plan
|
| 263 |
plan = SemanticQueryPlan(
|
| 264 |
intent=QueryIntent.FIND_CONTRACTS,
|
| 265 |
+
projections=["contract_identifier", "contract_value", "contract_status"],
|
| 266 |
filters=[],
|
| 267 |
limit=10
|
| 268 |
)
|
|
|
|
| 355 |
"""Test executing the same semantic query across all customers."""
|
| 356 |
plan = SemanticQueryPlan(
|
| 357 |
intent=QueryIntent.FIND_CONTRACTS,
|
| 358 |
+
projections=["contract_identifier", "contract_value", "contract_status"],
|
| 359 |
filters=[],
|
| 360 |
limit=5
|
| 361 |
)
|
|
@@ -157,21 +157,21 @@ class TestValueHarmonizer:
|
|
| 157 |
assert harmonized["contract_status"] == "Active"
|
| 158 |
assert harmonized["contract_value"] == 100000.0
|
| 159 |
|
| 160 |
-
def
|
| 161 |
-
"""Test harmonizing a row with
|
| 162 |
row = {
|
| 163 |
"contract_id": "A001",
|
| 164 |
-
"
|
| 165 |
}
|
| 166 |
|
| 167 |
field_mappings = {
|
| 168 |
"contract_id": "contract_identifier",
|
| 169 |
-
"
|
| 170 |
}
|
| 171 |
|
| 172 |
harmonized = result_harmonizer._harmonize_row(row, "customer_a", field_mappings)
|
| 173 |
|
| 174 |
-
assert harmonized["
|
| 175 |
|
| 176 |
|
| 177 |
class TestResultHarmonizer:
|
|
@@ -393,20 +393,20 @@ class TestResultHarmonizer:
|
|
| 393 |
assert all(row.data["contract_status"] == "Active" for row in filtered_result.results)
|
| 394 |
|
| 395 |
def test_aggregate_results_count(self, result_harmonizer):
|
| 396 |
-
"""Test aggregating results with count."""
|
| 397 |
harmonized = HarmonizedResult(
|
| 398 |
results=[
|
| 399 |
HarmonizedRow(
|
| 400 |
customer_id="customer_a",
|
| 401 |
-
data={"
|
| 402 |
),
|
| 403 |
HarmonizedRow(
|
| 404 |
customer_id="customer_a",
|
| 405 |
-
data={"
|
| 406 |
),
|
| 407 |
HarmonizedRow(
|
| 408 |
customer_id="customer_a",
|
| 409 |
-
data={"
|
| 410 |
),
|
| 411 |
],
|
| 412 |
total_count=3,
|
|
@@ -417,19 +417,19 @@ class TestResultHarmonizer:
|
|
| 417 |
|
| 418 |
aggregated = result_harmonizer.aggregate_results(
|
| 419 |
harmonized,
|
| 420 |
-
group_by=["
|
| 421 |
aggregations={"contract_value": "count"}
|
| 422 |
)
|
| 423 |
|
| 424 |
-
assert aggregated.total_count == 2 # Two
|
| 425 |
|
| 426 |
-
# Find
|
| 427 |
-
|
| 428 |
-
(r for r in aggregated.results if r.data.get("
|
| 429 |
None
|
| 430 |
)
|
| 431 |
-
assert
|
| 432 |
-
assert
|
| 433 |
|
| 434 |
def test_aggregate_results_sum(self, result_harmonizer):
|
| 435 |
"""Test aggregating results with sum."""
|
|
@@ -437,11 +437,11 @@ class TestResultHarmonizer:
|
|
| 437 |
results=[
|
| 438 |
HarmonizedRow(
|
| 439 |
customer_id="customer_a",
|
| 440 |
-
data={"
|
| 441 |
),
|
| 442 |
HarmonizedRow(
|
| 443 |
customer_id="customer_a",
|
| 444 |
-
data={"
|
| 445 |
),
|
| 446 |
],
|
| 447 |
total_count=2,
|
|
@@ -452,7 +452,7 @@ class TestResultHarmonizer:
|
|
| 452 |
|
| 453 |
aggregated = result_harmonizer.aggregate_results(
|
| 454 |
harmonized,
|
| 455 |
-
group_by=["
|
| 456 |
aggregations={"contract_value": "sum"}
|
| 457 |
)
|
| 458 |
|
|
@@ -465,11 +465,11 @@ class TestResultHarmonizer:
|
|
| 465 |
results=[
|
| 466 |
HarmonizedRow(
|
| 467 |
customer_id="customer_a",
|
| 468 |
-
data={"
|
| 469 |
),
|
| 470 |
HarmonizedRow(
|
| 471 |
customer_id="customer_a",
|
| 472 |
-
data={"
|
| 473 |
),
|
| 474 |
],
|
| 475 |
total_count=2,
|
|
@@ -480,7 +480,7 @@ class TestResultHarmonizer:
|
|
| 480 |
|
| 481 |
aggregated = result_harmonizer.aggregate_results(
|
| 482 |
harmonized,
|
| 483 |
-
group_by=["
|
| 484 |
aggregations={"contract_value": "avg"}
|
| 485 |
)
|
| 486 |
|
|
|
|
| 157 |
assert harmonized["contract_status"] == "Active"
|
| 158 |
assert harmonized["contract_value"] == 100000.0
|
| 159 |
|
| 160 |
+
def test_harmonize_row_with_status(self, result_harmonizer):
|
| 161 |
+
"""Test harmonizing a row with status normalization."""
|
| 162 |
row = {
|
| 163 |
"contract_id": "A001",
|
| 164 |
+
"status": "active"
|
| 165 |
}
|
| 166 |
|
| 167 |
field_mappings = {
|
| 168 |
"contract_id": "contract_identifier",
|
| 169 |
+
"status": "contract_status"
|
| 170 |
}
|
| 171 |
|
| 172 |
harmonized = result_harmonizer._harmonize_row(row, "customer_a", field_mappings)
|
| 173 |
|
| 174 |
+
assert harmonized["contract_status"] == "active"
|
| 175 |
|
| 176 |
|
| 177 |
class TestResultHarmonizer:
|
|
|
|
| 393 |
assert all(row.data["contract_status"] == "Active" for row in filtered_result.results)
|
| 394 |
|
| 395 |
def test_aggregate_results_count(self, result_harmonizer):
|
| 396 |
+
"""Test aggregating results with count by status."""
|
| 397 |
harmonized = HarmonizedResult(
|
| 398 |
results=[
|
| 399 |
HarmonizedRow(
|
| 400 |
customer_id="customer_a",
|
| 401 |
+
data={"contract_status": "active", "contract_value": 100000}
|
| 402 |
),
|
| 403 |
HarmonizedRow(
|
| 404 |
customer_id="customer_a",
|
| 405 |
+
data={"contract_status": "active", "contract_value": 200000}
|
| 406 |
),
|
| 407 |
HarmonizedRow(
|
| 408 |
customer_id="customer_a",
|
| 409 |
+
data={"contract_status": "inactive", "contract_value": 150000}
|
| 410 |
),
|
| 411 |
],
|
| 412 |
total_count=3,
|
|
|
|
| 417 |
|
| 418 |
aggregated = result_harmonizer.aggregate_results(
|
| 419 |
harmonized,
|
| 420 |
+
group_by=["contract_status"],
|
| 421 |
aggregations={"contract_value": "count"}
|
| 422 |
)
|
| 423 |
|
| 424 |
+
assert aggregated.total_count == 2 # Two statuses
|
| 425 |
|
| 426 |
+
# Find active group
|
| 427 |
+
active_row = next(
|
| 428 |
+
(r for r in aggregated.results if r.data.get("contract_status") == "active"),
|
| 429 |
None
|
| 430 |
)
|
| 431 |
+
assert active_row is not None
|
| 432 |
+
assert active_row.data["contract_value_count"] == 2
|
| 433 |
|
| 434 |
def test_aggregate_results_sum(self, result_harmonizer):
|
| 435 |
"""Test aggregating results with sum."""
|
|
|
|
| 437 |
results=[
|
| 438 |
HarmonizedRow(
|
| 439 |
customer_id="customer_a",
|
| 440 |
+
data={"contract_status": "active", "contract_value": 100000}
|
| 441 |
),
|
| 442 |
HarmonizedRow(
|
| 443 |
customer_id="customer_a",
|
| 444 |
+
data={"contract_status": "active", "contract_value": 200000}
|
| 445 |
),
|
| 446 |
],
|
| 447 |
total_count=2,
|
|
|
|
| 452 |
|
| 453 |
aggregated = result_harmonizer.aggregate_results(
|
| 454 |
harmonized,
|
| 455 |
+
group_by=["contract_status"],
|
| 456 |
aggregations={"contract_value": "sum"}
|
| 457 |
)
|
| 458 |
|
|
|
|
| 465 |
results=[
|
| 466 |
HarmonizedRow(
|
| 467 |
customer_id="customer_a",
|
| 468 |
+
data={"contract_status": "active", "contract_value": 100000}
|
| 469 |
),
|
| 470 |
HarmonizedRow(
|
| 471 |
customer_id="customer_a",
|
| 472 |
+
data={"contract_status": "active", "contract_value": 200000}
|
| 473 |
),
|
| 474 |
],
|
| 475 |
total_count=2,
|
|
|
|
| 480 |
|
| 481 |
aggregated = result_harmonizer.aggregate_results(
|
| 482 |
harmonized,
|
| 483 |
+
group_by=["contract_status"],
|
| 484 |
aggregations={"contract_value": "avg"}
|
| 485 |
)
|
| 486 |
|