sanzgiri committed on
Commit
4bb196e
·
1 Parent(s): a4cd896

v2.0: Dynamic result formatting and natural language customer selection

Browse files

Major features:
- Added target_customers field to SemanticQueryPlan for natural language database selection
- Updated query understanding agent to extract customer references (customer_a, customer A, etc.)
- Modified orchestrator to use extracted target_customers from query plan
- Fixed result harmonizer to include aggregation columns (count, sum, avg, etc.)
- Implemented dynamic table formatting in app.py that filters null columns
- Customer column now displays as A, B, C instead of customer_a, customer_b
- Removed customer_name, industry_sector, and contract_name concepts from knowledge graph
- Updated all tests to reflect schema changes (156 tests passing)

Key improvements:
- Queries like 'show contracts from customer A' now work naturally
- Count/sum/aggregation queries display correctly with only relevant columns
- No more empty columns with dashes in result tables
- Smart Customer column display (shown when querying multiple databases)
- All query types supported: regular, count, sum, filtered, multi-customer

app.py CHANGED
@@ -45,44 +45,101 @@ def format_result_table(result, limit: int = 10) -> str:
45
  first_row = result.results[0]
46
  all_columns = list(first_row.data.keys())
47
 
48
- # Define preferred field order and nice names
49
- field_order = [
50
- 'contract_identifier',
51
- 'customer_name',
52
- 'contract_name',
53
- 'industry_sector',
54
- 'contract_value',
55
- 'contract_status',
56
- 'contract_expiration'
57
- ]
58
-
59
- nice_names = {
60
- 'contract_identifier': 'Contract ID',
61
- 'contract_name': 'Contract Name',
62
- 'customer_name': 'Company', # Changed from Customer to Company
63
- 'contract_value': 'Value',
64
- 'contract_status': 'Status',
65
- 'contract_expiration': 'Expiration',
66
- 'contract_start': 'Start Date',
67
- 'industry_sector': 'Industry'
68
- }
69
-
70
- # Order columns: preferred order first, then any remaining
71
- columns = []
72
- for field in field_order:
73
- if field in all_columns:
74
- columns.append(field)
75
-
76
- # Add any remaining columns not in preferred order
77
  for col in all_columns:
78
- if col not in columns:
79
- columns.append(col)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  # Build markdown table with better formatting
82
  lines = []
83
 
84
  # Header with nicer column names
85
- display_cols = [nice_names.get(col, col) for col in columns]
 
 
 
 
 
 
 
 
86
  header = "| " + " | ".join(display_cols) + " |"
87
  separator = "|" + "|".join(["---" for _ in columns]) + "|"
88
  lines.append(header)
@@ -97,9 +154,17 @@ def format_result_table(result, limit: int = 10) -> str:
97
  # Format values for better display
98
  if val is None or val == "None":
99
  val = "—"
100
- elif isinstance(val, (int, float)) and col == 'contract_value':
101
- # Format currency values
102
- val = f"${val:,.0f}"
 
 
 
 
 
 
 
 
103
  else:
104
  val = str(val)
105
  values.append(val)
@@ -206,11 +271,15 @@ Welcome! I'm running in **{mode}**.
206
  I can help you query customer databases using natural language. Just ask me questions like:
207
  - "Show me all contracts"
208
  - "Find active contracts with value over 10000"
 
 
209
  - "Count contracts by status"
210
- - "List all customers"
 
211
 
212
  ### Available Commands:
213
- - `/customers` - List available customers
 
214
  - `/limit <number>` - Set max rows to display (default: 10)
215
  - `/debug on/off` - Toggle debug mode
216
  - `/stats` - Show query statistics
@@ -282,18 +351,32 @@ async def main(message: cl.Message):
282
  # Show available fields
283
  if result.results:
284
  fields = list(result.results[0].data.keys())
 
 
 
 
 
285
  field_names = {
 
286
  'contract_identifier': 'Contract ID',
287
- 'contract_name': 'Contract Name',
288
- 'customer_name': 'Company', # Changed from Customer to Company
289
  'contract_value': 'Value',
290
  'contract_status': 'Status',
291
  'contract_expiration': 'Expiration Date',
292
  'contract_start': 'Start Date',
293
- 'industry_sector': 'Industry'
 
 
 
 
294
  }
295
- field_list = ', '.join([f"`{field_names.get(f, f)}`" for f in fields])
296
- content_parts.append(f"**Showing {len(fields)} fields:** {field_list}\n")
 
 
 
 
 
 
297
 
298
  # Results table
299
  content_parts.append("### ✅ Query Results")
@@ -408,14 +491,21 @@ async def handle_command(command: str, orchestrator: ChatOrchestrator, debug_mod
408
  **Example Queries:**
409
  - "Show me all contracts"
410
  - "Find active contracts"
 
 
411
  - "Count contracts by status"
412
  - "List contracts with value over 10000"
413
- - "Show contracts expiring soon"
 
 
 
 
 
414
 
415
  **Commands:**
416
- - `/customers` - List available customers
417
- - `/select <customer_id>` - Query specific customer(s)
418
- - `/select all` - Query all customers (default)
419
  - `/limit <number>` - Set max rows to display (default: 10)
420
  - `/debug on/off` - Toggle debug mode
421
  - `/stats` - Show query statistics
@@ -425,8 +515,8 @@ async def handle_command(command: str, orchestrator: ChatOrchestrator, debug_mod
425
  **Tips:**
426
  - Use natural language - I'll understand it!
427
  - Be specific about what you want to see
428
- - Use filters like "active", "over 1000", "expiring soon"
429
- - Enable debug mode to see SQL queries
430
  """
431
  await cl.Message(content=help_msg).send()
432
 
 
45
  first_row = result.results[0]
46
  all_columns = list(first_row.data.keys())
47
 
48
+ # Add source database to the data (from customer_id in HarmonizedRow)
49
+ # Convert customer_a -> A, customer_b -> B, etc. for display
50
+ for row in result.results:
51
+ if 'source_db' not in row.data:
52
+ # Extract letter from customer_id (customer_a -> A)
53
+ customer_letter = row.customer_id.replace('customer_', '').upper()
54
+ row.data['source_db'] = customer_letter
55
+
56
+ # Add source_db to all_columns if not already there
57
+ if 'source_db' not in all_columns:
58
+ all_columns.insert(0, 'source_db')
59
+
60
+ # Filter to only columns that have non-null values in at least one row
61
+ columns_with_values = set()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  for col in all_columns:
63
+ if any(row.data.get(col) is not None for row in result.results):
64
+ columns_with_values.add(col)
65
+
66
+ # Check if multiple customers are being queried
67
+ multiple_customers = len(result.customers_queried) > 1
68
+
69
+ # Always include source_db if querying multiple customers
70
+ if multiple_customers and 'source_db' not in columns_with_values:
71
+ columns_with_values.add('source_db')
72
+
73
+ # Check if this is an aggregation/count query (no contract_identifier or all null)
74
+ is_aggregation = 'contract_identifier' not in columns_with_values
75
+
76
+ if is_aggregation:
77
+ # For aggregations, show source_db first, then all other non-null columns
78
+ columns = []
79
+ if 'source_db' in columns_with_values:
80
+ columns.append('source_db')
81
+ for col in columns_with_values:
82
+ if col != 'source_db':
83
+ columns.append(col)
84
+
85
+ # Nice names for aggregation fields
86
+ nice_names = {
87
+ 'source_db': 'Customer',
88
+ 'count_contract_identifier': 'Count',
89
+ 'sum_contract_value': 'Total Value',
90
+ 'avg_contract_value': 'Avg Value',
91
+ 'max_contract_value': 'Max Value',
92
+ 'min_contract_value': 'Min Value',
93
+ 'count': 'Count',
94
+ 'sum': 'Sum',
95
+ 'average': 'Average',
96
+ 'max': 'Max',
97
+ 'min': 'Min'
98
+ }
99
+ else:
100
+ # For regular queries, use preferred field order (only for columns with values)
101
+ field_order = [
102
+ 'source_db', # Show which customer (A, B, C, etc.)
103
+ 'contract_identifier', # Contract ID
104
+ 'contract_value',
105
+ 'contract_status',
106
+ 'contract_expiration',
107
+ 'contract_start'
108
+ ]
109
+
110
+ nice_names = {
111
+ 'source_db': 'Customer', # Which database (A, B, C, D, E, F)
112
+ 'contract_identifier': 'Contract ID',
113
+ 'contract_value': 'Value',
114
+ 'contract_status': 'Status',
115
+ 'contract_expiration': 'Expiration',
116
+ 'contract_start': 'Start Date'
117
+ }
118
+
119
+ # Order columns: preferred order first (only include if they have values)
120
+ columns = []
121
+ for field in field_order:
122
+ if field in columns_with_values:
123
+ columns.append(field)
124
+
125
+ # Add any remaining columns not in preferred order (that have values)
126
+ for col in columns_with_values:
127
+ if col not in columns:
128
+ columns.append(col)
129
 
130
  # Build markdown table with better formatting
131
  lines = []
132
 
133
  # Header with nicer column names
134
+ display_cols = []
135
+ for col in columns:
136
+ # Get nice name or convert snake_case to Title Case
137
+ if col in nice_names:
138
+ display_cols.append(nice_names[col])
139
+ else:
140
+ # Convert snake_case to Title Case
141
+ display_cols.append(col.replace('_', ' ').title())
142
+
143
  header = "| " + " | ".join(display_cols) + " |"
144
  separator = "|" + "|".join(["---" for _ in columns]) + "|"
145
  lines.append(header)
 
154
  # Format values for better display
155
  if val is None or val == "None":
156
  val = "—"
157
+ elif isinstance(val, (int, float)):
158
+ # Format numbers based on column type
159
+ if 'value' in col.lower() or 'sum' in col.lower():
160
+ # Format currency values
161
+ val = f"${val:,.0f}"
162
+ elif 'avg' in col.lower() or 'average' in col.lower():
163
+ # Format averages with 2 decimals
164
+ val = f"${val:,.2f}" if 'value' in col.lower() else f"{val:,.2f}"
165
+ else:
166
+ # Format counts and other integers
167
+ val = f"{val:,}"
168
  else:
169
  val = str(val)
170
  values.append(val)
 
271
  I can help you query customer databases using natural language. Just ask me questions like:
272
  - "Show me all contracts"
273
  - "Find active contracts with value over 10000"
274
+ - "Show contracts from customer A" or "from customer_a"
275
+ - "List contracts in customer B and C"
276
  - "Count contracts by status"
277
+
278
+ **Note:** Results show Customer column (A, B, C, etc.) indicating which database each contract is from.
279
 
280
  ### Available Commands:
281
+ - `/customers` - List available databases (customer_a through customer_f)
282
+ - `/select <customer_id>` - Query specific database (e.g., `/select customer_a`)
283
  - `/limit <number>` - Set max rows to display (default: 10)
284
  - `/debug on/off` - Toggle debug mode
285
  - `/stats` - Show query statistics
 
351
  # Show available fields
352
  if result.results:
353
  fields = list(result.results[0].data.keys())
354
+ # Filter out None values to get actual fields returned
355
+ actual_fields = [f for f in fields if any(
356
+ row.data.get(f) is not None for row in result.results
357
+ )]
358
+
359
  field_names = {
360
+ 'source_db': 'Customer',
361
  'contract_identifier': 'Contract ID',
 
 
362
  'contract_value': 'Value',
363
  'contract_status': 'Status',
364
  'contract_expiration': 'Expiration Date',
365
  'contract_start': 'Start Date',
366
+ 'count_contract_identifier': 'Count',
367
+ 'sum_contract_value': 'Total Value',
368
+ 'avg_contract_value': 'Avg Value',
369
+ 'max_contract_value': 'Max Value',
370
+ 'min_contract_value': 'Min Value'
371
  }
372
+
373
+ field_display = []
374
+ for f in actual_fields:
375
+ nice_name = field_names.get(f, f.replace('_', ' ').title())
376
+ field_display.append(f"`{nice_name}`")
377
+
378
+ field_list = ', '.join(field_display)
379
+ content_parts.append(f"**Showing {len(actual_fields)} fields:** {field_list}\n")
380
 
381
  # Results table
382
  content_parts.append("### ✅ Query Results")
 
491
  **Example Queries:**
492
  - "Show me all contracts"
493
  - "Find active contracts"
494
+ - "Show contracts from customer A" or "from customer_a"
495
+ - "Query customer B and customer C databases"
496
  - "Count contracts by status"
497
  - "List contracts with value over 10000"
498
+ - "Show active contracts in customer_a expiring in 2026"
499
+
500
+ **Important Notes:**
501
+ - You can query specific customers by saying "customer A", "customer_a", "from customer B", etc.
502
+ - Without specifying, queries run across all available databases
503
+ - **Customer** column in results shows which database (A, B, C, etc.) each contract came from
504
 
505
  **Commands:**
506
+ - `/customers` - List available databases (customer_a through customer_f)
507
+ - `/select <customer_id>` - Query specific database (e.g., `/select customer_a`)
508
+ - `/select all` - Query all databases (default)
509
  - `/limit <number>` - Set max rows to display (default: 10)
510
  - `/debug on/off` - Toggle debug mode
511
  - `/stats` - Show query statistics
 
515
  **Tips:**
516
  - Use natural language - I'll understand it!
517
  - Be specific about what you want to see
518
+ - Use filters like "active", "over 10000", "expiring in 90 days"
519
+ - Enable debug mode to see SQL queries and semantic plans
520
  """
521
  await cl.Message(content=help_msg).send()
522
 
initialize_kg.py CHANGED
@@ -280,137 +280,13 @@ def initialize_knowledge_graph() -> SchemaKnowledgeGraph:
280
  )
281
 
282
  # ========================================================================
283
- # 5. INDUSTRY SECTOR
284
  # ========================================================================
285
- print(" Adding concept: industry_sector")
286
- kg.add_concept(
287
- concept_id="industry_sector",
288
- concept_name="Industry Sector",
289
- description="Business industry or vertical of the customer",
290
- aliases=["industry", "sector", "vertical", "business_sector"]
291
- )
292
-
293
- kg.add_customer_mapping(
294
- concept_id="industry_sector",
295
- customer_id="customer_a",
296
- table_name="contracts",
297
- column_name="industry",
298
- data_type="TEXT",
299
- semantic_type=SemanticType.TEXT
300
- )
301
-
302
- kg.add_customer_mapping(
303
- concept_id="industry_sector",
304
- customer_id="customer_b",
305
- table_name="contract_headers",
306
- column_name="sector",
307
- data_type="TEXT",
308
- semantic_type=SemanticType.TEXT
309
- )
310
-
311
- kg.add_customer_mapping(
312
- concept_id="industry_sector",
313
- customer_id="customer_c",
314
- table_name="contracts",
315
- column_name="business_sector",
316
- data_type="TEXT",
317
- semantic_type=SemanticType.TEXT
318
- )
319
-
320
- kg.add_customer_mapping(
321
- concept_id="industry_sector",
322
- customer_id="customer_d",
323
- table_name="contracts",
324
- column_name="industry",
325
- data_type="TEXT",
326
- semantic_type=SemanticType.TEXT
327
- )
328
-
329
- kg.add_customer_mapping(
330
- concept_id="industry_sector",
331
- customer_id="customer_e",
332
- table_name="contracts",
333
- column_name="industry",
334
- data_type="TEXT",
335
- semantic_type=SemanticType.TEXT
336
- )
337
-
338
- kg.add_customer_mapping(
339
- concept_id="industry_sector",
340
- customer_id="customer_f",
341
- table_name="contracts",
342
- column_name="sector",
343
- data_type="TEXT",
344
- semantic_type=SemanticType.TEXT
345
- )
346
-
347
- # ========================================================================
348
- # 6. CUSTOMER NAME
349
- # ========================================================================
350
- print(" Adding concept: customer_name")
351
- kg.add_concept(
352
- concept_id="customer_name",
353
- concept_name="Customer Name",
354
- description="Name of the customer/client organization",
355
- aliases=["client", "account", "organization", "client_name"]
356
- )
357
-
358
- kg.add_customer_mapping(
359
- concept_id="customer_name",
360
- customer_id="customer_a",
361
- table_name="contracts",
362
- column_name="customer_name",
363
- data_type="TEXT",
364
- semantic_type=SemanticType.TEXT
365
- )
366
-
367
- kg.add_customer_mapping(
368
- concept_id="customer_name",
369
- customer_id="customer_b",
370
- table_name="contract_headers",
371
- column_name="client_name",
372
- data_type="TEXT",
373
- semantic_type=SemanticType.TEXT
374
- )
375
-
376
- kg.add_customer_mapping(
377
- concept_id="customer_name",
378
- customer_id="customer_c",
379
- table_name="contracts",
380
- column_name="account",
381
- data_type="TEXT",
382
- semantic_type=SemanticType.TEXT
383
- )
384
-
385
- kg.add_customer_mapping(
386
- concept_id="customer_name",
387
- customer_id="customer_d",
388
- table_name="contracts",
389
- column_name="customer_org",
390
- data_type="TEXT",
391
- semantic_type=SemanticType.TEXT
392
- )
393
-
394
- kg.add_customer_mapping(
395
- concept_id="customer_name",
396
- customer_id="customer_e",
397
- table_name="contracts",
398
- column_name="customer_name",
399
- data_type="TEXT",
400
- semantic_type=SemanticType.TEXT
401
- )
402
-
403
- kg.add_customer_mapping(
404
- concept_id="customer_name",
405
- customer_id="customer_f",
406
- table_name="contracts",
407
- column_name="account",
408
- data_type="TEXT",
409
- semantic_type=SemanticType.TEXT
410
- )
411
-
412
- # ========================================================================
413
- # 7. CONTRACT START
414
  # ========================================================================
415
  print(" Adding concept: contract_start")
416
  kg.add_concept(
 
280
  )
281
 
282
  # ========================================================================
283
+ # 5. CONTRACT START
284
  # ========================================================================
285
+ # NOTE: industry_sector and customer_name concepts were removed.
286
+ # - customer_name: Refers to company names in contracts (e.g., "Global Tech Inc")
287
+ # which conflicts with database selection (customer_a, customer_b, etc.)
288
+ # - industry_sector: Internal attribute that doesn't represent typical
289
+ # business queries users would make.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  # ========================================================================
291
  print(" Adding concept: contract_start")
292
  kg.add_concept(
knowledge_graph.json CHANGED
@@ -271,140 +271,6 @@
271
  }
272
  }
273
  },
274
- "industry_sector": {
275
- "concept_id": "industry_sector",
276
- "concept_name": "Industry Sector",
277
- "description": "Business industry or vertical of the customer",
278
- "aliases": [
279
- "industry",
280
- "sector",
281
- "vertical",
282
- "business_sector"
283
- ],
284
- "customer_mappings": {
285
- "customer_a": {
286
- "customer_id": "customer_a",
287
- "table_name": "contracts",
288
- "column_name": "industry",
289
- "data_type": "TEXT",
290
- "semantic_type": "text",
291
- "transformation": null,
292
- "join_requirements": []
293
- },
294
- "customer_b": {
295
- "customer_id": "customer_b",
296
- "table_name": "contract_headers",
297
- "column_name": "sector",
298
- "data_type": "TEXT",
299
- "semantic_type": "text",
300
- "transformation": null,
301
- "join_requirements": []
302
- },
303
- "customer_c": {
304
- "customer_id": "customer_c",
305
- "table_name": "contracts",
306
- "column_name": "business_sector",
307
- "data_type": "TEXT",
308
- "semantic_type": "text",
309
- "transformation": null,
310
- "join_requirements": []
311
- },
312
- "customer_d": {
313
- "customer_id": "customer_d",
314
- "table_name": "contracts",
315
- "column_name": "industry",
316
- "data_type": "TEXT",
317
- "semantic_type": "text",
318
- "transformation": null,
319
- "join_requirements": []
320
- },
321
- "customer_e": {
322
- "customer_id": "customer_e",
323
- "table_name": "contracts",
324
- "column_name": "industry",
325
- "data_type": "TEXT",
326
- "semantic_type": "text",
327
- "transformation": null,
328
- "join_requirements": []
329
- },
330
- "customer_f": {
331
- "customer_id": "customer_f",
332
- "table_name": "contracts",
333
- "column_name": "sector",
334
- "data_type": "TEXT",
335
- "semantic_type": "text",
336
- "transformation": null,
337
- "join_requirements": []
338
- }
339
- }
340
- },
341
- "customer_name": {
342
- "concept_id": "customer_name",
343
- "concept_name": "Customer Name",
344
- "description": "Name of the customer/client organization",
345
- "aliases": [
346
- "client",
347
- "account",
348
- "organization",
349
- "client_name"
350
- ],
351
- "customer_mappings": {
352
- "customer_a": {
353
- "customer_id": "customer_a",
354
- "table_name": "contracts",
355
- "column_name": "customer_name",
356
- "data_type": "TEXT",
357
- "semantic_type": "text",
358
- "transformation": null,
359
- "join_requirements": []
360
- },
361
- "customer_b": {
362
- "customer_id": "customer_b",
363
- "table_name": "contract_headers",
364
- "column_name": "client_name",
365
- "data_type": "TEXT",
366
- "semantic_type": "text",
367
- "transformation": null,
368
- "join_requirements": []
369
- },
370
- "customer_c": {
371
- "customer_id": "customer_c",
372
- "table_name": "contracts",
373
- "column_name": "account",
374
- "data_type": "TEXT",
375
- "semantic_type": "text",
376
- "transformation": null,
377
- "join_requirements": []
378
- },
379
- "customer_d": {
380
- "customer_id": "customer_d",
381
- "table_name": "contracts",
382
- "column_name": "customer_org",
383
- "data_type": "TEXT",
384
- "semantic_type": "text",
385
- "transformation": null,
386
- "join_requirements": []
387
- },
388
- "customer_e": {
389
- "customer_id": "customer_e",
390
- "table_name": "contracts",
391
- "column_name": "customer_name",
392
- "data_type": "TEXT",
393
- "semantic_type": "text",
394
- "transformation": null,
395
- "join_requirements": []
396
- },
397
- "customer_f": {
398
- "customer_id": "customer_f",
399
- "table_name": "contracts",
400
- "column_name": "account",
401
- "data_type": "TEXT",
402
- "semantic_type": "text",
403
- "transformation": null,
404
- "join_requirements": []
405
- }
406
- }
407
- },
408
  "contract_start": {
409
  "concept_id": "contract_start",
410
  "concept_name": "Contract Start Date",
@@ -471,73 +337,6 @@
471
  "join_requirements": []
472
  }
473
  }
474
- },
475
- "contract_name": {
476
- "concept_id": "contract_name",
477
- "concept_name": "Contract Name",
478
- "description": "The name or title of the contract",
479
- "aliases": [
480
- "name",
481
- "title",
482
- "contract title",
483
- "agreement name"
484
- ],
485
- "customer_mappings": {
486
- "customer_a": {
487
- "customer_id": "customer_a",
488
- "table_name": "contracts",
489
- "column_name": "contract_name",
490
- "data_type": "TEXT",
491
- "semantic_type": "text",
492
- "transformation": null,
493
- "join_requirements": []
494
- },
495
- "customer_b": {
496
- "customer_id": "customer_b",
497
- "table_name": "contract_headers",
498
- "column_name": "name",
499
- "data_type": "TEXT",
500
- "semantic_type": "text",
501
- "transformation": null,
502
- "join_requirements": []
503
- },
504
- "customer_c": {
505
- "customer_id": "customer_c",
506
- "table_name": "contracts",
507
- "column_name": "name",
508
- "data_type": "TEXT",
509
- "semantic_type": "text",
510
- "transformation": null,
511
- "join_requirements": []
512
- },
513
- "customer_d": {
514
- "customer_id": "customer_d",
515
- "table_name": "contracts",
516
- "column_name": "name",
517
- "data_type": "TEXT",
518
- "semantic_type": "text",
519
- "transformation": null,
520
- "join_requirements": []
521
- },
522
- "customer_e": {
523
- "customer_id": "customer_e",
524
- "table_name": "contracts",
525
- "column_name": "contract_title",
526
- "data_type": "TEXT",
527
- "semantic_type": "text",
528
- "transformation": null,
529
- "join_requirements": []
530
- },
531
- "customer_f": {
532
- "customer_id": "customer_f",
533
- "table_name": "contracts",
534
- "column_name": "agreement_name",
535
- "data_type": "TEXT",
536
- "semantic_type": "text",
537
- "transformation": null,
538
- "join_requirements": []
539
- }
540
- }
541
  }
542
  },
543
  "transformations": {
 
271
  }
272
  }
273
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  "contract_start": {
275
  "concept_id": "contract_start",
276
  "concept_name": "Contract Start Date",
 
337
  "join_requirements": []
338
  }
339
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
  }
341
  },
342
  "transformations": {
schema_translator/agents/query_understanding.py CHANGED
@@ -43,6 +43,18 @@ class QueryUnderstandingAgent:
43
  Available Semantic Concepts:
44
  {concept_list}
45
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  Your task is to:
47
  1. Identify the user's query intent (find_contracts, count_contracts, aggregate_values)
48
  2. Determine which semantic concepts are relevant
@@ -88,7 +100,8 @@ Return ONLY valid JSON matching this schema:
88
  "function": "count" | "sum" | "average",
89
  "concept": "concept_name"
90
  }}], // optional, for aggregate_values intent
91
- "limit": 10 // optional, for find_contracts intent
 
92
  }}
93
 
94
  IMPORTANT:
@@ -110,7 +123,8 @@ Result:
110
  "intent": "find_contracts",
111
  "filters": [],
112
  "projections": [], // Empty projections means return ALL available fields
113
- "limit": 100
 
114
  }}
115
 
116
  Query: "Show me all active contracts"
@@ -125,7 +139,18 @@ Result:
125
  }}
126
  ],
127
  "projections": [], // Empty projections means return ALL available fields
128
- "limit": 100
 
 
 
 
 
 
 
 
 
 
 
129
  }}
130
 
131
  Query: "How many contracts expire in the next 90 days?"
@@ -156,26 +181,21 @@ Result:
156
  "value": 2000000
157
  }}
158
  ],
159
- "projections": ["contract_identifier", "contract_value", "customer_name"]
160
  }}
161
 
162
- Query: "List technology contracts expiring in 2026"
163
  Result:
164
  {{
165
  "intent": "find_contracts",
166
  "filters": [
167
- {{
168
- "concept": "industry_sector",
169
- "operator": "equals",
170
- "value": "Technology"
171
- }},
172
  {{
173
  "concept": "contract_expiration",
174
  "operator": "between",
175
  "value": ["2026-01-01", "2026-12-31"]
176
  }}
177
  ],
178
- "projections": ["contract_identifier", "industry_sector", "contract_expiration"],
179
  "limit": 50
180
  }}
181
 
 
43
  Available Semantic Concepts:
44
  {concept_list}
45
 
46
+ CRITICAL RULE - Database Selection:
47
+ When users mention "customer" or database references, extract them into target_customers field:
48
+ - "customer_a", "customer_b", "customer_c", "customer_d", "customer_e", "customer_f"
49
+ - "customer A", "customer B", "database A", "from customer_a", "customer a database"
50
+ - Any reference to which specific database(s) to query
51
+
52
+ DO NOT create semantic filters for these - use the target_customers field instead.
53
+ Examples:
54
+ - "Show contracts from customer_a" → target_customers: ["customer_a"]
55
+ - "Query customer A and B" → target_customers: ["customer_a", "customer_b"]
56
+ - "Show all contracts" → target_customers: null (means query all)
57
+
58
  Your task is to:
59
  1. Identify the user's query intent (find_contracts, count_contracts, aggregate_values)
60
  2. Determine which semantic concepts are relevant
 
100
  "function": "count" | "sum" | "average",
101
  "concept": "concept_name"
102
  }}], // optional, for aggregate_values intent
103
+ "limit": 10, // optional, for find_contracts intent
104
+ "target_customers": ["customer_a", "customer_b"] // optional, null or omit for all customers
105
  }}
106
 
107
  IMPORTANT:
 
123
  "intent": "find_contracts",
124
  "filters": [],
125
  "projections": [], // Empty projections means return ALL available fields
126
+ "limit": 100,
127
+ "target_customers": null // null means query all customers
128
  }}
129
 
130
  Query: "Show me all active contracts"
 
139
  }}
140
  ],
141
  "projections": [], // Empty projections means return ALL available fields
142
+ "limit": 100,
143
+ "target_customers": null
144
+ }}
145
+
146
+ Query: "Show contracts from customer_a" or "Query customer A database"
147
+ Result:
148
+ {{
149
+ "intent": "find_contracts",
150
+ "filters": [],
151
+ "projections": [],
152
+ "limit": 100,
153
+ "target_customers": ["customer_a"] // Extract database reference
154
  }}
155
 
156
  Query: "How many contracts expire in the next 90 days?"
 
181
  "value": 2000000
182
  }}
183
  ],
184
+ "projections": ["contract_identifier", "contract_value", "contract_status"]
185
  }}
186
 
187
+ Query: "List contracts expiring in 2026"
188
  Result:
189
  {{
190
  "intent": "find_contracts",
191
  "filters": [
 
 
 
 
 
192
  {{
193
  "concept": "contract_expiration",
194
  "operator": "between",
195
  "value": ["2026-01-01", "2026-12-31"]
196
  }}
197
  ],
198
+ "projections": ["contract_identifier", "contract_expiration", "contract_value"],
199
  "limit": 50
200
  }}
201
 
schema_translator/mock_data.py CHANGED
@@ -108,22 +108,18 @@ class MockDataGenerator:
108
  conn = sqlite3.connect(db_path)
109
  cursor = conn.cursor()
110
 
111
- # Create table
112
  cursor.execute("""
113
  CREATE TABLE IF NOT EXISTS contracts (
114
  contract_id INTEGER PRIMARY KEY,
115
- contract_name TEXT NOT NULL,
116
- customer_name TEXT NOT NULL,
117
  contract_value INTEGER NOT NULL,
118
  status TEXT NOT NULL,
119
  expiry_date TEXT NOT NULL,
120
- start_date TEXT NOT NULL,
121
- industry TEXT NOT NULL
122
  )
123
  """)
124
 
125
  # Generate data
126
- industries = self.INDUSTRIES["customer_a"]
127
  for i in range(1, 51):
128
  start_date, expiry_date = self.generate_dates()
129
 
@@ -135,18 +131,14 @@ class MockDataGenerator:
135
 
136
  cursor.execute("""
137
  INSERT INTO contracts
138
- (contract_id, contract_name, customer_name, contract_value, status,
139
- expiry_date, start_date, industry)
140
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
141
  """, (
142
  i,
143
- self.generate_contract_name(i),
144
- self.generate_company_name(),
145
  self.generate_contract_value(is_annual=False),
146
  status,
147
  expiry_date.strftime("%Y-%m-%d"),
148
- start_date.strftime("%Y-%m-%d"),
149
- random.choice(industries)
150
  ))
151
 
152
  conn.commit()
@@ -159,15 +151,12 @@ class MockDataGenerator:
159
  conn = sqlite3.connect(db_path)
160
  cursor = conn.cursor()
161
 
162
- # Create tables
163
  cursor.execute("""
164
  CREATE TABLE IF NOT EXISTS contract_headers (
165
  id INTEGER PRIMARY KEY,
166
- contract_name TEXT NOT NULL,
167
- client_name TEXT NOT NULL,
168
  contract_value INTEGER NOT NULL,
169
- start_date TEXT NOT NULL,
170
- sector TEXT NOT NULL
171
  )
172
  """)
173
 
@@ -191,22 +180,18 @@ class MockDataGenerator:
191
  """)
192
 
193
  # Generate data
194
- industries = self.INDUSTRIES["customer_b"]
195
  for i in range(1, 51):
196
  start_date, expiry_date = self.generate_dates()
197
 
198
  # Insert header
199
  cursor.execute("""
200
  INSERT INTO contract_headers
201
- (id, contract_name, client_name, contract_value, start_date, sector)
202
- VALUES (?, ?, ?, ?, ?, ?)
203
  """, (
204
  i,
205
- self.generate_contract_name(i),
206
- self.generate_company_name(),
207
  self.generate_contract_value(is_annual=False),
208
- start_date.strftime("%Y-%m-%d"),
209
- random.choice(industries)
210
  ))
211
 
212
  # Insert status history (1-3 status changes)
@@ -244,22 +229,18 @@ class MockDataGenerator:
244
  conn = sqlite3.connect(db_path)
245
  cursor = conn.cursor()
246
 
247
- # Create table with different column names
248
  cursor.execute("""
249
  CREATE TABLE IF NOT EXISTS contracts (
250
  id INTEGER PRIMARY KEY,
251
- name TEXT NOT NULL,
252
- account TEXT NOT NULL,
253
  total_value INTEGER NOT NULL,
254
  current_status TEXT NOT NULL,
255
  expiration_date TEXT NOT NULL,
256
- inception_date TEXT NOT NULL,
257
- business_sector TEXT NOT NULL
258
  )
259
  """)
260
 
261
  # Generate data
262
- industries = self.INDUSTRIES["customer_c"]
263
  for i in range(1, 51):
264
  start_date, expiry_date = self.generate_dates()
265
 
@@ -270,18 +251,14 @@ class MockDataGenerator:
270
 
271
  cursor.execute("""
272
  INSERT INTO contracts
273
- (id, name, account, total_value, current_status,
274
- expiration_date, inception_date, business_sector)
275
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
276
  """, (
277
  i,
278
- self.generate_contract_name(i),
279
- self.generate_company_name(),
280
  self.generate_contract_value(is_annual=False),
281
  status,
282
  expiry_date.strftime("%Y-%m-%d"),
283
- start_date.strftime("%Y-%m-%d"),
284
- random.choice(industries)
285
  ))
286
 
287
  conn.commit()
@@ -294,22 +271,18 @@ class MockDataGenerator:
294
  conn = sqlite3.connect(db_path)
295
  cursor = conn.cursor()
296
 
297
- # Create table with days_remaining instead of date
298
  cursor.execute("""
299
  CREATE TABLE IF NOT EXISTS contracts (
300
  contract_id INTEGER PRIMARY KEY,
301
- contract_title TEXT NOT NULL,
302
- customer_org TEXT NOT NULL,
303
  contract_value INTEGER NOT NULL,
304
  status TEXT NOT NULL,
305
  days_remaining INTEGER NOT NULL,
306
- start_date TEXT NOT NULL,
307
- industry TEXT NOT NULL
308
  )
309
  """)
310
 
311
  # Generate data
312
- industries = self.INDUSTRIES["customer_d"]
313
  for i in range(1, 51):
314
  start_date, expiry_date = self.generate_dates()
315
 
@@ -323,18 +296,14 @@ class MockDataGenerator:
323
 
324
  cursor.execute("""
325
  INSERT INTO contracts
326
- (contract_id, contract_title, customer_org, contract_value, status,
327
- days_remaining, start_date, industry)
328
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
329
  """, (
330
  i,
331
- self.generate_contract_name(i),
332
- self.generate_company_name(),
333
  self.generate_contract_value(is_annual=False),
334
  status,
335
  days_remaining,
336
- start_date.strftime("%Y-%m-%d"),
337
- random.choice(industries)
338
  ))
339
 
340
  conn.commit()
@@ -347,23 +316,19 @@ class MockDataGenerator:
347
  conn = sqlite3.connect(db_path)
348
  cursor = conn.cursor()
349
 
350
- # Create table
351
  cursor.execute("""
352
  CREATE TABLE IF NOT EXISTS contracts (
353
  contract_id INTEGER PRIMARY KEY,
354
- contract_name TEXT NOT NULL,
355
- customer_name TEXT NOT NULL,
356
  contract_value INTEGER NOT NULL,
357
  term_years REAL NOT NULL,
358
  status TEXT NOT NULL,
359
  expiry_date TEXT NOT NULL,
360
- start_date TEXT NOT NULL,
361
- industry TEXT NOT NULL
362
  )
363
  """)
364
 
365
  # Generate data
366
- industries = self.INDUSTRIES["customer_e"]
367
  for i in range(1, 51):
368
  start_date, expiry_date = self.generate_dates()
369
 
@@ -378,19 +343,15 @@ class MockDataGenerator:
378
 
379
  cursor.execute("""
380
  INSERT INTO contracts
381
- (contract_id, contract_name, customer_name, contract_value, term_years,
382
- status, expiry_date, start_date, industry)
383
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
384
  """, (
385
  i,
386
- self.generate_contract_name(i),
387
- self.generate_company_name(),
388
  self.generate_contract_value(is_annual=False),
389
  term_years,
390
  status,
391
  expiry_date.strftime("%Y-%m-%d"),
392
- start_date.strftime("%Y-%m-%d"),
393
- random.choice(industries)
394
  ))
395
 
396
  conn.commit()
@@ -403,23 +364,19 @@ class MockDataGenerator:
403
  conn = sqlite3.connect(db_path)
404
  cursor = conn.cursor()
405
 
406
- # Create table
407
  cursor.execute("""
408
  CREATE TABLE IF NOT EXISTS contracts (
409
  contract_id INTEGER PRIMARY KEY,
410
- name TEXT NOT NULL,
411
- account TEXT NOT NULL,
412
  contract_value INTEGER NOT NULL,
413
  term_years REAL NOT NULL,
414
  status TEXT NOT NULL,
415
  expiration_date TEXT NOT NULL,
416
- start_date TEXT NOT NULL,
417
- sector TEXT NOT NULL
418
  )
419
  """)
420
 
421
  # Generate data
422
- industries = self.INDUSTRIES["customer_f"]
423
  for i in range(1, 51):
424
  start_date, expiry_date = self.generate_dates()
425
 
@@ -434,19 +391,15 @@ class MockDataGenerator:
434
 
435
  cursor.execute("""
436
  INSERT INTO contracts
437
- (contract_id, name, account, contract_value, term_years,
438
- status, expiration_date, start_date, sector)
439
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
440
  """, (
441
  i,
442
- self.generate_contract_name(i),
443
- self.generate_company_name(),
444
  self.generate_contract_value(is_annual=True), # ANNUAL value
445
  term_years,
446
  status,
447
  expiry_date.strftime("%Y-%m-%d"),
448
- start_date.strftime("%Y-%m-%d"),
449
- random.choice(industries)
450
  ))
451
 
452
  conn.commit()
 
108
  conn = sqlite3.connect(db_path)
109
  cursor = conn.cursor()
110
 
111
+ # Create table (removed customer_name and industry - not in knowledge graph)
112
  cursor.execute("""
113
  CREATE TABLE IF NOT EXISTS contracts (
114
  contract_id INTEGER PRIMARY KEY,
 
 
115
  contract_value INTEGER NOT NULL,
116
  status TEXT NOT NULL,
117
  expiry_date TEXT NOT NULL,
118
+ start_date TEXT NOT NULL
 
119
  )
120
  """)
121
 
122
  # Generate data
 
123
  for i in range(1, 51):
124
  start_date, expiry_date = self.generate_dates()
125
 
 
131
 
132
  cursor.execute("""
133
  INSERT INTO contracts
134
+ (contract_id, contract_value, status, expiry_date, start_date)
135
+ VALUES (?, ?, ?, ?, ?)
 
136
  """, (
137
  i,
 
 
138
  self.generate_contract_value(is_annual=False),
139
  status,
140
  expiry_date.strftime("%Y-%m-%d"),
141
+ start_date.strftime("%Y-%m-%d")
 
142
  ))
143
 
144
  conn.commit()
 
151
  conn = sqlite3.connect(db_path)
152
  cursor = conn.cursor()
153
 
154
+ # Create tables (removed client_name and sector)
155
  cursor.execute("""
156
  CREATE TABLE IF NOT EXISTS contract_headers (
157
  id INTEGER PRIMARY KEY,
 
 
158
  contract_value INTEGER NOT NULL,
159
+ start_date TEXT NOT NULL
 
160
  )
161
  """)
162
 
 
180
  """)
181
 
182
  # Generate data
 
183
  for i in range(1, 51):
184
  start_date, expiry_date = self.generate_dates()
185
 
186
  # Insert header
187
  cursor.execute("""
188
  INSERT INTO contract_headers
189
+ (id, contract_value, start_date)
190
+ VALUES (?, ?, ?)
191
  """, (
192
  i,
 
 
193
  self.generate_contract_value(is_annual=False),
194
+ start_date.strftime("%Y-%m-%d")
 
195
  ))
196
 
197
  # Insert status history (1-3 status changes)
 
229
  conn = sqlite3.connect(db_path)
230
  cursor = conn.cursor()
231
 
232
+ # Create table with different column names (removed account and business_sector)
233
  cursor.execute("""
234
  CREATE TABLE IF NOT EXISTS contracts (
235
  id INTEGER PRIMARY KEY,
 
 
236
  total_value INTEGER NOT NULL,
237
  current_status TEXT NOT NULL,
238
  expiration_date TEXT NOT NULL,
239
+ inception_date TEXT NOT NULL
 
240
  )
241
  """)
242
 
243
  # Generate data
 
244
  for i in range(1, 51):
245
  start_date, expiry_date = self.generate_dates()
246
 
 
251
 
252
  cursor.execute("""
253
  INSERT INTO contracts
254
+ (id, total_value, current_status, expiration_date, inception_date)
255
+ VALUES (?, ?, ?, ?, ?)
 
256
  """, (
257
  i,
 
 
258
  self.generate_contract_value(is_annual=False),
259
  status,
260
  expiry_date.strftime("%Y-%m-%d"),
261
+ start_date.strftime("%Y-%m-%d")
 
262
  ))
263
 
264
  conn.commit()
 
271
  conn = sqlite3.connect(db_path)
272
  cursor = conn.cursor()
273
 
274
+ # Create table with days_remaining instead of date (removed customer_org and industry)
275
  cursor.execute("""
276
  CREATE TABLE IF NOT EXISTS contracts (
277
  contract_id INTEGER PRIMARY KEY,
 
 
278
  contract_value INTEGER NOT NULL,
279
  status TEXT NOT NULL,
280
  days_remaining INTEGER NOT NULL,
281
+ start_date TEXT NOT NULL
 
282
  )
283
  """)
284
 
285
  # Generate data
 
286
  for i in range(1, 51):
287
  start_date, expiry_date = self.generate_dates()
288
 
 
296
 
297
  cursor.execute("""
298
  INSERT INTO contracts
299
+ (contract_id, contract_value, status, days_remaining, start_date)
300
+ VALUES (?, ?, ?, ?, ?)
 
301
  """, (
302
  i,
 
 
303
  self.generate_contract_value(is_annual=False),
304
  status,
305
  days_remaining,
306
+ start_date.strftime("%Y-%m-%d")
 
307
  ))
308
 
309
  conn.commit()
 
316
  conn = sqlite3.connect(db_path)
317
  cursor = conn.cursor()
318
 
319
+ # Create table (removed customer_name and industry)
320
  cursor.execute("""
321
  CREATE TABLE IF NOT EXISTS contracts (
322
  contract_id INTEGER PRIMARY KEY,
 
 
323
  contract_value INTEGER NOT NULL,
324
  term_years REAL NOT NULL,
325
  status TEXT NOT NULL,
326
  expiry_date TEXT NOT NULL,
327
+ start_date TEXT NOT NULL
 
328
  )
329
  """)
330
 
331
  # Generate data
 
332
  for i in range(1, 51):
333
  start_date, expiry_date = self.generate_dates()
334
 
 
343
 
344
  cursor.execute("""
345
  INSERT INTO contracts
346
+ (contract_id, contract_value, term_years, status, expiry_date, start_date)
347
+ VALUES (?, ?, ?, ?, ?, ?)
 
348
  """, (
349
  i,
 
 
350
  self.generate_contract_value(is_annual=False),
351
  term_years,
352
  status,
353
  expiry_date.strftime("%Y-%m-%d"),
354
+ start_date.strftime("%Y-%m-%d")
 
355
  ))
356
 
357
  conn.commit()
 
364
  conn = sqlite3.connect(db_path)
365
  cursor = conn.cursor()
366
 
367
+ # Create table (removed account and sector)
368
  cursor.execute("""
369
  CREATE TABLE IF NOT EXISTS contracts (
370
  contract_id INTEGER PRIMARY KEY,
 
 
371
  contract_value INTEGER NOT NULL,
372
  term_years REAL NOT NULL,
373
  status TEXT NOT NULL,
374
  expiration_date TEXT NOT NULL,
375
+ start_date TEXT NOT NULL
 
376
  )
377
  """)
378
 
379
  # Generate data
 
380
  for i in range(1, 51):
381
  start_date, expiry_date = self.generate_dates()
382
 
 
391
 
392
  cursor.execute("""
393
  INSERT INTO contracts
394
+ (contract_id, contract_value, term_years, status, expiration_date, start_date)
395
+ VALUES (?, ?, ?, ?, ?, ?)
 
396
  """, (
397
  i,
 
 
398
  self.generate_contract_value(is_annual=True), # ANNUAL value
399
  term_years,
400
  status,
401
  expiry_date.strftime("%Y-%m-%d"),
402
+ start_date.strftime("%Y-%m-%d")
 
403
  ))
404
 
405
  conn.commit()
schema_translator/models.py CHANGED
@@ -90,8 +90,8 @@ class SchemaTable(BaseModel):
90
 
91
  class CustomerSchema(BaseModel):
92
  """Represents the complete schema for a customer database."""
93
- customer_id: str = Field(..., description="Unique customer identifier")
94
- customer_name: str = Field(..., description="Customer display name")
95
  tables: List[SchemaTable] = Field(..., description="Tables in this schema")
96
  semantic_notes: Dict[str, str] = Field(
97
  default_factory=dict,
@@ -170,6 +170,10 @@ class SemanticQueryPlan(BaseModel):
170
  description="Ordering (concept, direction) pairs"
171
  )
172
  limit: Optional[int] = Field(None, description="Maximum number of results")
 
 
 
 
173
 
174
  model_config = {"use_enum_values": True}
175
 
 
90
 
91
  class CustomerSchema(BaseModel):
92
  """Represents the complete schema for a customer database."""
93
+ customer_id: str = Field(..., description="Unique customer identifier (e.g., customer_a)")
94
+ customer_name: Optional[str] = Field(None, description="Optional customer display name")
95
  tables: List[SchemaTable] = Field(..., description="Tables in this schema")
96
  semantic_notes: Dict[str, str] = Field(
97
  default_factory=dict,
 
170
  description="Ordering (concept, direction) pairs"
171
  )
172
  limit: Optional[int] = Field(None, description="Maximum number of results")
173
+ target_customers: Optional[List[str]] = Field(
174
+ None,
175
+ description="Specific customer databases to query (e.g., ['customer_a', 'customer_b']). None means all customers."
176
+ )
177
 
178
  model_config = {"use_enum_values": True}
179
 
schema_translator/orchestrator.py CHANGED
@@ -122,14 +122,21 @@ class ChatOrchestrator:
122
  if debug:
123
  logger.info(f"Semantic plan: {semantic_plan}")
124
 
125
- # Step 3: Execute query across customers
126
- logger.info(f"Executing query across {len(customer_ids) if customer_ids else 'all'} customers...")
 
 
 
 
 
 
 
127
  result = self.result_harmonizer.execute_across_customers(
128
  semantic_plan,
129
- customer_ids=customer_ids
130
  )
131
 
132
- # Step 4: Calculate total execution time
133
  total_time_ms = (time.time() - start_time) * 1000
134
 
135
  logger.info(
@@ -138,7 +145,7 @@ class ChatOrchestrator:
138
  f"{total_time_ms:.2f}ms"
139
  )
140
 
141
- # Step 5: Add to history
142
  self._add_to_history(
143
  query_text=query_text,
144
  semantic_plan=semantic_plan,
@@ -147,7 +154,7 @@ class ChatOrchestrator:
147
  error=None
148
  )
149
 
150
- # Step 6: Build response
151
  response = {
152
  "success": True,
153
  "query_text": query_text,
 
122
  if debug:
123
  logger.info(f"Semantic plan: {semantic_plan}")
124
 
125
+ # Step 3: Determine which customers to query
126
+ # Priority: explicit customer_ids parameter > target_customers from query > all customers
127
+ target_customers = customer_ids
128
+ if target_customers is None and semantic_plan.target_customers:
129
+ target_customers = semantic_plan.target_customers
130
+ logger.info(f"Extracted target customers from query: {target_customers}")
131
+
132
+ # Step 4: Execute query across customers
133
+ logger.info(f"Executing query across {len(target_customers) if target_customers else 'all'} customers...")
134
  result = self.result_harmonizer.execute_across_customers(
135
  semantic_plan,
136
+ customer_ids=target_customers
137
  )
138
 
139
+ # Step 5: Calculate total execution time
140
  total_time_ms = (time.time() - start_time) * 1000
141
 
142
  logger.info(
 
145
  f"{total_time_ms:.2f}ms"
146
  )
147
 
148
+ # Step 6: Add to history
149
  self._add_to_history(
150
  query_text=query_text,
151
  semantic_plan=semantic_plan,
 
154
  error=None
155
  )
156
 
157
+ # Step 7: Build response
158
  response = {
159
  "success": True,
160
  "query_text": query_text,
schema_translator/result_harmonizer.py CHANGED
@@ -687,6 +687,7 @@ class ResultHarmonizer:
687
  """
688
  harmonized = {}
689
 
 
690
  for customer_field, concept_id in field_mappings.items():
691
  if customer_field in row:
692
  value = row[customer_field]
@@ -702,4 +703,12 @@ class ResultHarmonizer:
702
  # Field not present in row
703
  harmonized[concept_id] = None
704
 
 
 
 
 
 
 
 
 
705
  return harmonized
 
687
  """
688
  harmonized = {}
689
 
690
+ # First, map all fields that have concept mappings
691
  for customer_field, concept_id in field_mappings.items():
692
  if customer_field in row:
693
  value = row[customer_field]
 
703
  # Field not present in row
704
  harmonized[concept_id] = None
705
 
706
+ # Second, include any unmapped columns (like aggregation results)
707
+ # These are columns like count_contract_identifier, sum_contract_value, etc.
708
+ for field_name, value in row.items():
709
+ if field_name not in field_mappings:
710
+ # This is an unmapped column (likely an aggregation)
711
+ # Include it as-is with its original name
712
+ harmonized[field_name] = value
713
+
714
  return harmonized
tests/test_agents.py CHANGED
@@ -98,11 +98,11 @@ class TestQueryUnderstandingAgent:
98
 
99
  def test_multi_filter_query(self, query_agent):
100
  """Test parsing a query with multiple filters."""
101
- query = "List technology contracts expiring in 2026"
102
  plan = query_agent.understand_query(query)
103
 
104
  assert plan.intent == QueryIntent.FIND_CONTRACTS
105
- assert len(plan.filters) >= 2
106
 
107
  def test_explain_query_plan(self, query_agent):
108
  """Test generating human-readable explanations."""
@@ -122,7 +122,6 @@ class TestSchemaAnalyzerAgent:
122
  # Create a simple test schema
123
  schema = CustomerSchema(
124
  customer_id="test_customer",
125
- customer_name="Test Customer",
126
  tables=[
127
  SchemaTable(
128
  name="contracts",
@@ -173,7 +172,6 @@ class TestSchemaAnalyzerAgent:
173
 
174
  schema = CustomerSchema(
175
  customer_id="test_customer",
176
- customer_name="Test Customer",
177
  tables=[
178
  SchemaTable(
179
  name="contracts",
@@ -301,7 +299,6 @@ class TestEndToEndAgent:
301
  # Create a test schema
302
  schema = CustomerSchema(
303
  customer_id="test_new_customer",
304
- customer_name="Test New Customer",
305
  tables=[
306
  SchemaTable(
307
  name="agreements",
 
98
 
99
  def test_multi_filter_query(self, query_agent):
100
  """Test parsing a query with multiple filters."""
101
+ query = "List active contracts worth more than 1 million expiring in 2026"
102
  plan = query_agent.understand_query(query)
103
 
104
  assert plan.intent == QueryIntent.FIND_CONTRACTS
105
+ assert len(plan.filters) >= 2 # Should have status + value + date filters
106
 
107
  def test_explain_query_plan(self, query_agent):
108
  """Test generating human-readable explanations."""
 
122
  # Create a simple test schema
123
  schema = CustomerSchema(
124
  customer_id="test_customer",
 
125
  tables=[
126
  SchemaTable(
127
  name="contracts",
 
172
 
173
  schema = CustomerSchema(
174
  customer_id="test_customer",
 
175
  tables=[
176
  SchemaTable(
177
  name="contracts",
 
299
  # Create a test schema
300
  schema = CustomerSchema(
301
  customer_id="test_new_customer",
 
302
  tables=[
303
  SchemaTable(
304
  name="agreements",
tests/test_knowledge_graph.py CHANGED
@@ -269,12 +269,13 @@ class TestLoadedKnowledgeGraph:
269
  kg = SchemaKnowledgeGraph()
270
  kg.load()
271
 
272
- # Check concepts exist (8 concepts including contract_name)
273
- assert len(kg.concepts) == 8
274
  assert kg.get_concept("contract_identifier") is not None
275
- assert kg.get_concept("contract_name") is not None
276
  assert kg.get_concept("contract_value") is not None
277
  assert kg.get_concept("contract_expiration") is not None
 
 
278
 
279
  def test_all_customers_mapped(self):
280
  """Test that all 6 customers have mappings."""
 
269
  kg = SchemaKnowledgeGraph()
270
  kg.load()
271
 
272
+ # Check concepts exist (5 concepts, removed customer_name and industry_sector)
273
+ assert len(kg.concepts) == 5
274
  assert kg.get_concept("contract_identifier") is not None
 
275
  assert kg.get_concept("contract_value") is not None
276
  assert kg.get_concept("contract_expiration") is not None
277
+ assert kg.get_concept("contract_status") is not None
278
+ assert kg.get_concept("contract_start") is not None
279
 
280
  def test_all_customers_mapped(self):
281
  """Test that all 6 customers have mappings."""
tests/test_models.py CHANGED
@@ -89,7 +89,6 @@ class TestSchemaModels:
89
 
90
  schema = CustomerSchema(
91
  customer_id="customer_a",
92
- customer_name="Customer A",
93
  tables=[table],
94
  semantic_notes={"contract_value": "lifetime total"}
95
  )
 
89
 
90
  schema = CustomerSchema(
91
  customer_id="customer_a",
 
92
  tables=[table],
93
  semantic_notes={"contract_value": "lifetime total"}
94
  )
tests/test_query_execution.py CHANGED
@@ -151,16 +151,16 @@ class TestQueryCompiler:
151
 
152
  assert "LIMIT 10" in sql
153
 
154
- def test_industry_filter(self, compiler):
155
- """Test filtering by industry."""
156
  plan = SemanticQueryPlan(
157
  intent=QueryIntent.FIND_CONTRACTS,
158
- projections=["contract_identifier", "industry_sector"],
159
  filters=[
160
  QueryFilter(
161
- concept="industry_sector",
162
  operator=QueryOperator.EQUALS,
163
- value="Technology"
164
  )
165
  ]
166
  )
@@ -168,8 +168,8 @@ class TestQueryCompiler:
168
  sql = compiler.compile_for_customer(plan, "customer_a")
169
 
170
  assert "WHERE" in sql
171
- assert "industry" in sql
172
- assert "Technology" in sql
173
 
174
 
175
  class TestDatabaseExecutor:
@@ -182,7 +182,7 @@ class TestDatabaseExecutor:
182
 
183
  def test_simple_query(self, executor):
184
  """Test executing a simple query."""
185
- sql = "SELECT contract_id, contract_name FROM contracts LIMIT 5"
186
  result = executor.execute_query("customer_a", sql)
187
 
188
  assert result.success
@@ -197,7 +197,7 @@ class TestDatabaseExecutor:
197
 
198
  assert result.success
199
  assert result.row_count == 1
200
- assert result.data[0]["count"] == 100
201
 
202
  def test_invalid_sql(self, executor):
203
  """Test handling of invalid SQL."""
@@ -223,12 +223,12 @@ class TestDatabaseExecutor:
223
  def test_count_rows(self, executor):
224
  """Test counting rows in a table."""
225
  count = executor.count_rows("customer_a", "contracts")
226
- assert count == 100
227
 
228
  def test_customer_b_multi_table(self, executor):
229
  """Test querying Customer B's multi-table schema."""
230
  sql = """
231
- SELECT h.id, h.contract_name, r.renewal_date
232
  FROM contract_headers AS h
233
  JOIN renewal_schedule AS r ON h.id = r.contract_id
234
  LIMIT 5
@@ -262,7 +262,7 @@ class TestIntegration:
262
  # Create query plan
263
  plan = SemanticQueryPlan(
264
  intent=QueryIntent.FIND_CONTRACTS,
265
- projections=["contract_identifier", "contract_value", "customer_name"],
266
  filters=[],
267
  limit=10
268
  )
@@ -355,7 +355,7 @@ class TestIntegration:
355
  """Test executing the same semantic query across all customers."""
356
  plan = SemanticQueryPlan(
357
  intent=QueryIntent.FIND_CONTRACTS,
358
- projections=["contract_identifier", "contract_value", "customer_name"],
359
  filters=[],
360
  limit=5
361
  )
 
151
 
152
  assert "LIMIT 10" in sql
153
 
154
+ def test_status_filter(self, compiler):
155
+ """Test filtering by contract status."""
156
  plan = SemanticQueryPlan(
157
  intent=QueryIntent.FIND_CONTRACTS,
158
+ projections=["contract_identifier", "contract_status"],
159
  filters=[
160
  QueryFilter(
161
+ concept="contract_status",
162
  operator=QueryOperator.EQUALS,
163
+ value="active"
164
  )
165
  ]
166
  )
 
168
  sql = compiler.compile_for_customer(plan, "customer_a")
169
 
170
  assert "WHERE" in sql
171
+ assert "status" in sql
172
+ assert "active" in sql
173
 
174
 
175
  class TestDatabaseExecutor:
 
182
 
183
  def test_simple_query(self, executor):
184
  """Test executing a simple query."""
185
+ sql = "SELECT contract_id, contract_value FROM contracts LIMIT 5"
186
  result = executor.execute_query("customer_a", sql)
187
 
188
  assert result.success
 
197
 
198
  assert result.success
199
  assert result.row_count == 1
200
+ assert result.data[0]["count"] == 50 # 50 contracts per database
201
 
202
  def test_invalid_sql(self, executor):
203
  """Test handling of invalid SQL."""
 
223
  def test_count_rows(self, executor):
224
  """Test counting rows in a table."""
225
  count = executor.count_rows("customer_a", "contracts")
226
+ assert count == 50 # 50 contracts per database
227
 
228
  def test_customer_b_multi_table(self, executor):
229
  """Test querying Customer B's multi-table schema."""
230
  sql = """
231
+ SELECT h.id, h.contract_value, r.renewal_date
232
  FROM contract_headers AS h
233
  JOIN renewal_schedule AS r ON h.id = r.contract_id
234
  LIMIT 5
 
262
  # Create query plan
263
  plan = SemanticQueryPlan(
264
  intent=QueryIntent.FIND_CONTRACTS,
265
+ projections=["contract_identifier", "contract_value", "contract_status"],
266
  filters=[],
267
  limit=10
268
  )
 
355
  """Test executing the same semantic query across all customers."""
356
  plan = SemanticQueryPlan(
357
  intent=QueryIntent.FIND_CONTRACTS,
358
+ projections=["contract_identifier", "contract_value", "contract_status"],
359
  filters=[],
360
  limit=5
361
  )
tests/test_result_harmonization.py CHANGED
@@ -157,21 +157,21 @@ class TestValueHarmonizer:
157
  assert harmonized["contract_status"] == "Active"
158
  assert harmonized["contract_value"] == 100000.0
159
 
160
- def test_harmonize_row_with_industry(self, result_harmonizer):
161
- """Test harmonizing a row with industry normalization."""
162
  row = {
163
  "contract_id": "A001",
164
- "industry": "tech"
165
  }
166
 
167
  field_mappings = {
168
  "contract_id": "contract_identifier",
169
- "industry": "industry_sector"
170
  }
171
 
172
  harmonized = result_harmonizer._harmonize_row(row, "customer_a", field_mappings)
173
 
174
- assert harmonized["industry_sector"] == "Technology"
175
 
176
 
177
  class TestResultHarmonizer:
@@ -393,20 +393,20 @@ class TestResultHarmonizer:
393
  assert all(row.data["contract_status"] == "Active" for row in filtered_result.results)
394
 
395
  def test_aggregate_results_count(self, result_harmonizer):
396
- """Test aggregating results with count."""
397
  harmonized = HarmonizedResult(
398
  results=[
399
  HarmonizedRow(
400
  customer_id="customer_a",
401
- data={"industry_sector": "Technology", "contract_value": 100000}
402
  ),
403
  HarmonizedRow(
404
  customer_id="customer_a",
405
- data={"industry_sector": "Technology", "contract_value": 200000}
406
  ),
407
  HarmonizedRow(
408
  customer_id="customer_a",
409
- data={"industry_sector": "Healthcare", "contract_value": 150000}
410
  ),
411
  ],
412
  total_count=3,
@@ -417,19 +417,19 @@ class TestResultHarmonizer:
417
 
418
  aggregated = result_harmonizer.aggregate_results(
419
  harmonized,
420
- group_by=["industry_sector"],
421
  aggregations={"contract_value": "count"}
422
  )
423
 
424
- assert aggregated.total_count == 2 # Two industries
425
 
426
- # Find Technology group
427
- tech_row = next(
428
- (r for r in aggregated.results if r.data.get("industry_sector") == "Technology"),
429
  None
430
  )
431
- assert tech_row is not None
432
- assert tech_row.data["contract_value_count"] == 2
433
 
434
  def test_aggregate_results_sum(self, result_harmonizer):
435
  """Test aggregating results with sum."""
@@ -437,11 +437,11 @@ class TestResultHarmonizer:
437
  results=[
438
  HarmonizedRow(
439
  customer_id="customer_a",
440
- data={"industry_sector": "Technology", "contract_value": 100000}
441
  ),
442
  HarmonizedRow(
443
  customer_id="customer_a",
444
- data={"industry_sector": "Technology", "contract_value": 200000}
445
  ),
446
  ],
447
  total_count=2,
@@ -452,7 +452,7 @@ class TestResultHarmonizer:
452
 
453
  aggregated = result_harmonizer.aggregate_results(
454
  harmonized,
455
- group_by=["industry_sector"],
456
  aggregations={"contract_value": "sum"}
457
  )
458
 
@@ -465,11 +465,11 @@ class TestResultHarmonizer:
465
  results=[
466
  HarmonizedRow(
467
  customer_id="customer_a",
468
- data={"industry_sector": "Technology", "contract_value": 100000}
469
  ),
470
  HarmonizedRow(
471
  customer_id="customer_a",
472
- data={"industry_sector": "Technology", "contract_value": 200000}
473
  ),
474
  ],
475
  total_count=2,
@@ -480,7 +480,7 @@ class TestResultHarmonizer:
480
 
481
  aggregated = result_harmonizer.aggregate_results(
482
  harmonized,
483
- group_by=["industry_sector"],
484
  aggregations={"contract_value": "avg"}
485
  )
486
 
 
157
  assert harmonized["contract_status"] == "Active"
158
  assert harmonized["contract_value"] == 100000.0
159
 
160
+ def test_harmonize_row_with_status(self, result_harmonizer):
161
+ """Test harmonizing a row with status normalization."""
162
  row = {
163
  "contract_id": "A001",
164
+ "status": "active"
165
  }
166
 
167
  field_mappings = {
168
  "contract_id": "contract_identifier",
169
+ "status": "contract_status"
170
  }
171
 
172
  harmonized = result_harmonizer._harmonize_row(row, "customer_a", field_mappings)
173
 
174
+ assert harmonized["contract_status"] == "active"
175
 
176
 
177
  class TestResultHarmonizer:
 
393
  assert all(row.data["contract_status"] == "Active" for row in filtered_result.results)
394
 
395
  def test_aggregate_results_count(self, result_harmonizer):
396
+ """Test aggregating results with count by status."""
397
  harmonized = HarmonizedResult(
398
  results=[
399
  HarmonizedRow(
400
  customer_id="customer_a",
401
+ data={"contract_status": "active", "contract_value": 100000}
402
  ),
403
  HarmonizedRow(
404
  customer_id="customer_a",
405
+ data={"contract_status": "active", "contract_value": 200000}
406
  ),
407
  HarmonizedRow(
408
  customer_id="customer_a",
409
+ data={"contract_status": "inactive", "contract_value": 150000}
410
  ),
411
  ],
412
  total_count=3,
 
417
 
418
  aggregated = result_harmonizer.aggregate_results(
419
  harmonized,
420
+ group_by=["contract_status"],
421
  aggregations={"contract_value": "count"}
422
  )
423
 
424
+ assert aggregated.total_count == 2 # Two statuses
425
 
426
+ # Find active group
427
+ active_row = next(
428
+ (r for r in aggregated.results if r.data.get("contract_status") == "active"),
429
  None
430
  )
431
+ assert active_row is not None
432
+ assert active_row.data["contract_value_count"] == 2
433
 
434
  def test_aggregate_results_sum(self, result_harmonizer):
435
  """Test aggregating results with sum."""
 
437
  results=[
438
  HarmonizedRow(
439
  customer_id="customer_a",
440
+ data={"contract_status": "active", "contract_value": 100000}
441
  ),
442
  HarmonizedRow(
443
  customer_id="customer_a",
444
+ data={"contract_status": "active", "contract_value": 200000}
445
  ),
446
  ],
447
  total_count=2,
 
452
 
453
  aggregated = result_harmonizer.aggregate_results(
454
  harmonized,
455
+ group_by=["contract_status"],
456
  aggregations={"contract_value": "sum"}
457
  )
458
 
 
465
  results=[
466
  HarmonizedRow(
467
  customer_id="customer_a",
468
+ data={"contract_status": "active", "contract_value": 100000}
469
  ),
470
  HarmonizedRow(
471
  customer_id="customer_a",
472
+ data={"contract_status": "active", "contract_value": 200000}
473
  ),
474
  ],
475
  total_count=2,
 
480
 
481
  aggregated = result_harmonizer.aggregate_results(
482
  harmonized,
483
+ group_by=["contract_status"],
484
  aggregations={"contract_value": "avg"}
485
  )
486