"""Result harmonization for combining and normalizing multi-customer query results."""
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta
from typing import Any, Callable, Dict, List, Optional
from schema_translator.database_executor import DatabaseExecutor
from schema_translator.knowledge_graph import SchemaKnowledgeGraph
from schema_translator.models import (
HarmonizedResult,
HarmonizedRow,
NormalizedValue,
QueryResult,
SemanticQueryPlan,
SemanticType,
)
from schema_translator.query_compiler import QueryCompiler
class ResultHarmonizer:
"""Harmonizes query results across multiple customer databases."""
def __init__(
self,
knowledge_graph: SchemaKnowledgeGraph,
executor: Optional[DatabaseExecutor] = None
):
"""Initialize the result harmonizer.
Args:
knowledge_graph: Knowledge graph with concept mappings
executor: Database executor (creates new one if not provided)
"""
self.knowledge_graph = knowledge_graph
self.executor = executor or DatabaseExecutor()
self.compiler = QueryCompiler(knowledge_graph)
def execute_across_customers(
self,
query_plan: SemanticQueryPlan,
customer_ids: Optional[List[str]] = None,
parallel: bool = True
) -> HarmonizedResult:
"""Execute a semantic query across multiple customers and harmonize results.
Args:
query_plan: Semantic query plan to execute
customer_ids: List of customer IDs to query (all if None)
parallel: Whether to execute queries in parallel
Returns:
HarmonizedResult with combined and normalized data
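        Example (illustrative; the customer ID shown is a placeholder for a
        customer_*.db stem in the configured database directory):
            harmonizer = ResultHarmonizer(knowledge_graph)
            all_results = harmonizer.execute_across_customers(plan)
            one_result = harmonizer.execute_across_customers(
                plan, customer_ids=["customer_acme"], parallel=False
            )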
"""
start_time = time.time()
# Determine which customers to query
if customer_ids is None:
# Get all customer IDs from database directory
db_dir = self.executor.config.database_dir
db_files = list(db_dir.glob("customer_*.db"))
customer_ids = [f.stem for f in db_files]
# Execute queries for each customer
if parallel and len(customer_ids) > 1:
results = self._execute_parallel(query_plan, customer_ids)
else:
results = self._execute_sequential(query_plan, customer_ids)
# Harmonize results
harmonized = self._harmonize_results(query_plan, results)
# Calculate execution time
execution_time_ms = (time.time() - start_time) * 1000
harmonized.execution_time_ms = execution_time_ms
return harmonized
def _execute_parallel(
self,
query_plan: SemanticQueryPlan,
customer_ids: List[str]
) -> Dict[str, QueryResult]:
"""Execute queries in parallel across customers.
Args:
query_plan: Semantic query plan
customer_ids: List of customer IDs
Returns:
Map of customer_id to QueryResult
"""
results = {}
with ThreadPoolExecutor(max_workers=min(len(customer_ids), 10)) as executor:
# Submit all tasks
future_to_customer = {
executor.submit(self._execute_for_customer, query_plan, cid): cid
for cid in customer_ids
}
# Collect results as they complete
for future in as_completed(future_to_customer):
customer_id = future_to_customer[future]
try:
results[customer_id] = future.result()
except Exception as e:
# Create error result
results[customer_id] = QueryResult(
customer_id=customer_id,
data=[],
sql_executed="",
execution_time_ms=0,
row_count=0,
error=str(e)
)
return results
def _execute_sequential(
self,
query_plan: SemanticQueryPlan,
customer_ids: List[str]
) -> Dict[str, QueryResult]:
"""Execute queries sequentially across customers.
Args:
query_plan: Semantic query plan
customer_ids: List of customer IDs
Returns:
Map of customer_id to QueryResult
"""
results = {}
for customer_id in customer_ids:
try:
results[customer_id] = self._execute_for_customer(query_plan, customer_id)
except Exception as e:
# Create error result
results[customer_id] = QueryResult(
customer_id=customer_id,
data=[],
sql_executed="",
execution_time_ms=0,
row_count=0,
error=str(e)
)
return results
def _execute_for_customer(
self,
query_plan: SemanticQueryPlan,
customer_id: str
) -> QueryResult:
"""Execute query for a single customer.
Args:
query_plan: Semantic query plan
customer_id: Customer ID
Returns:
QueryResult
"""
# Compile query for this customer
sql = self.compiler.compile_for_customer(query_plan, customer_id)
# Execute query
result = self.executor.execute_query(customer_id, sql)
return result
def _harmonize_results(
self,
query_plan: SemanticQueryPlan,
results: Dict[str, QueryResult]
) -> HarmonizedResult:
"""Harmonize results from multiple customers.
Args:
query_plan: Original semantic query plan
results: Map of customer_id to QueryResult
Returns:
HarmonizedResult with unified data
"""
harmonized_rows = []
customers_succeeded = []
customers_failed = []
errors = {}
# Determine field mappings from query plan
concepts = self._get_concepts_from_plan(query_plan)
for customer_id, result in results.items():
if result.error:
customers_failed.append(customer_id)
errors[customer_id] = result.error
continue
customers_succeeded.append(customer_id)
# Get field mappings for this customer
field_mappings = self._build_field_mappings(customer_id, concepts)
# Harmonize each row
for row in result.data:
harmonized_data = self._harmonize_row(
row, customer_id, field_mappings
)
harmonized_rows.append(
HarmonizedRow(
customer_id=customer_id,
data=harmonized_data,
metadata={
"original_row": row,
"sql_executed": result.sql_executed
}
)
)
return HarmonizedResult(
results=harmonized_rows,
total_count=len(harmonized_rows),
customers_queried=list(results.keys()),
customers_succeeded=customers_succeeded,
customers_failed=customers_failed,
errors=errors,
execution_time_ms=0 # Will be set by caller
)
def _get_concepts_from_plan(self, query_plan: SemanticQueryPlan) -> List[str]:
"""Extract concept IDs from query plan.
Args:
query_plan: Semantic query plan
Returns:
List of concept IDs
"""
concepts = set()
# Add projected concepts
if query_plan.projections:
concepts.update(query_plan.projections)
else:
# If no projections specified (SELECT *), include all available concepts
# This ensures all fields are harmonized when user asks for "all" data
all_concepts = self.knowledge_graph.get_all_concepts()
concepts.update([c.concept_id for c in all_concepts])
# Add filtered concepts
if query_plan.filters:
for filter_obj in query_plan.filters:
concepts.add(filter_obj.concept)
# Add aggregated concepts
if query_plan.aggregations:
for agg in query_plan.aggregations:
concepts.add(agg.concept)
return list(concepts)
def _build_field_mappings(
self,
customer_id: str,
concepts: List[str]
) -> Dict[str, str]:
"""Build mapping from customer field names to concept IDs.
Args:
customer_id: Customer ID
concepts: List of concept IDs to map
Returns:
Map of customer field name to concept ID
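        Example (illustrative; the column names are placeholders):
            {"acct_industry": "industry_sector", "deal_amt": "contract_value"}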
"""
field_mappings = {}
for concept_id in concepts:
mapping = self.knowledge_graph.get_mapping(concept_id, customer_id)
if mapping:
field_mappings[mapping.column_name] = concept_id
return field_mappings
def aggregate_results(
self,
harmonized_result: HarmonizedResult,
group_by: Optional[List[str]] = None,
aggregations: Optional[Dict[str, str]] = None
) -> HarmonizedResult:
"""Perform additional aggregation on harmonized results.
Args:
harmonized_result: Harmonized results to aggregate
group_by: List of fields to group by
aggregations: Map of field to aggregation function (count, sum, avg, etc.)
Returns:
New HarmonizedResult with aggregated data
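        Example (illustrative; the concept IDs are placeholders for IDs defined
        in the knowledge graph):
            by_industry = harmonizer.aggregate_results(
                result,
                group_by=["industry_sector"],
                aggregations={"contract_value": "sum", "contract_identifier": "count"},
            )
            # Rows then carry contract_value_sum and contract_identifier_count.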
"""
if not group_by and not aggregations:
return harmonized_result
# Group rows
groups: Dict[tuple, List[HarmonizedRow]] = {}
for row in harmonized_result.results:
if group_by:
key = tuple(row.data.get(field) for field in group_by)
else:
key = () # Single group
if key not in groups:
groups[key] = []
groups[key].append(row)
# Aggregate each group
aggregated_rows = []
for key, group_rows in groups.items():
aggregated_data = {}
# Add group_by fields
if group_by:
for i, field in enumerate(group_by):
aggregated_data[field] = key[i]
# Apply aggregations
if aggregations:
for field, agg_func in aggregations.items():
values = [
row.data.get(field)
for row in group_rows
if row.data.get(field) is not None
]
if agg_func == "count":
aggregated_data[f"{field}_count"] = len(values)
elif agg_func == "sum":
aggregated_data[f"{field}_sum"] = sum(values) if values else None
elif agg_func == "avg":
aggregated_data[f"{field}_avg"] = (
sum(values) / len(values) if values else None
)
elif agg_func == "min":
aggregated_data[f"{field}_min"] = min(values) if values else None
elif agg_func == "max":
aggregated_data[f"{field}_max"] = max(values) if values else None
            # Create aggregated row; customer_id is set to "aggregated" and the
            # contributing customers are recorded in metadata
aggregated_rows.append(
HarmonizedRow(
customer_id="aggregated",
data=aggregated_data,
metadata={
"row_count": len(group_rows),
"customers": list(set(r.customer_id for r in group_rows))
}
)
)
return HarmonizedResult(
results=aggregated_rows,
total_count=len(aggregated_rows),
customers_queried=harmonized_result.customers_queried,
customers_succeeded=harmonized_result.customers_succeeded,
customers_failed=harmonized_result.customers_failed,
errors=harmonized_result.errors,
execution_time_ms=harmonized_result.execution_time_ms
)
def sort_results(
self,
harmonized_result: HarmonizedResult,
sort_by: str,
descending: bool = False
) -> HarmonizedResult:
"""Sort harmonized results by a field.
Args:
harmonized_result: Results to sort
sort_by: Field name to sort by
descending: Whether to sort descending
Returns:
New HarmonizedResult with sorted data
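        Example (illustrative; "contract_value" is a placeholder concept ID):
            top_deals = harmonizer.sort_results(result, "contract_value", descending=True)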
"""
        # Use a (is_missing, value) sort key so rows lacking the field are kept
        # together instead of raising a TypeError when a string default is
        # compared against numeric values.
        sorted_rows = sorted(
            harmonized_result.results,
            key=lambda r: (r.data.get(sort_by) is None, r.data.get(sort_by)),
            reverse=descending
        )
return HarmonizedResult(
results=sorted_rows,
total_count=harmonized_result.total_count,
customers_queried=harmonized_result.customers_queried,
customers_succeeded=harmonized_result.customers_succeeded,
customers_failed=harmonized_result.customers_failed,
errors=harmonized_result.errors,
execution_time_ms=harmonized_result.execution_time_ms
)
def filter_results(
self,
harmonized_result: HarmonizedResult,
        filter_func: Callable[[HarmonizedRow], bool]
) -> HarmonizedResult:
"""Filter harmonized results using a custom function.
Args:
harmonized_result: Results to filter
filter_func: Function that takes a HarmonizedRow and returns bool
Returns:
New HarmonizedResult with filtered data
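        Example (illustrative; "contract_value" is a placeholder concept ID):
            large_deals = harmonizer.filter_results(
                result,
                lambda row: (row.data.get("contract_value") or 0) > 100_000,
            )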
"""
filtered_rows = [
row for row in harmonized_result.results
if filter_func(row)
]
return HarmonizedResult(
results=filtered_rows,
total_count=len(filtered_rows),
customers_queried=harmonized_result.customers_queried,
customers_succeeded=harmonized_result.customers_succeeded,
customers_failed=harmonized_result.customers_failed,
errors=harmonized_result.errors,
execution_time_ms=harmonized_result.execution_time_ms
)
# Value normalization methods (formerly ValueHarmonizer)
def _normalize_value(
self,
value: Any,
customer_id: str,
concept_id: str,
target_type: Optional[SemanticType] = None
) -> NormalizedValue:
"""Normalize a value from a customer schema to a common format.
Args:
value: The value to normalize
customer_id: Customer ID for context
concept_id: The semantic concept this value represents
target_type: Optional target semantic type to convert to
Returns:
NormalizedValue with original and normalized forms
"""
# Get concept mapping for this customer
mapping = self.knowledge_graph.get_mapping(concept_id, customer_id)
if not mapping:
# No mapping found, return as-is
return NormalizedValue(
original_value=value,
normalized_value=value,
original_type="unknown",
normalized_type="unknown",
transformation_applied=None
)
# Handle semantic_type as either SemanticType enum or string
if isinstance(mapping.semantic_type, SemanticType):
original_type = mapping.semantic_type.value
semantic_type_enum = mapping.semantic_type
else:
original_type = str(mapping.semantic_type)
semantic_type_enum = SemanticType(mapping.semantic_type)
transformation = mapping.transformation
# Apply transformation if specified
if transformation:
normalized = self._apply_transformation(
value, transformation, customer_id, concept_id
)
transformation_applied = transformation
else:
normalized = value
transformation_applied = None
# Convert type if target specified
if target_type and target_type != semantic_type_enum:
normalized = self._convert_type(normalized, semantic_type_enum, target_type)
if transformation_applied:
transformation_applied += f" + type_conversion_to_{target_type.value}"
else:
transformation_applied = f"type_conversion_to_{target_type.value}"
return NormalizedValue(
original_value=value,
normalized_value=normalized,
original_type=original_type,
normalized_type=target_type.value if target_type else original_type,
transformation_applied=transformation_applied
)
def _apply_transformation(
self,
value: Any,
transformation: str,
customer_id: str,
concept_id: str
) -> Any:
"""Apply a transformation to a value.
Args:
value: Value to transform
transformation: Transformation SQL or expression
customer_id: Customer ID for context
concept_id: Concept ID for context
Returns:
Transformed value
"""
# Handle days_remaining -> end_date conversion
if "CURRENT_DATE" in transformation or "julianday" in transformation:
return self._days_to_date(value)
        # Handle annual_value -> lifetime_value style conversions
        if "contract_length" in transformation or "*" in transformation:
            # Row-level multiplications (e.g. annual_value * contract_length)
            # need values from other columns that are not available here, so
            # they are applied at SQL compile time; pass the value through.
            return value
# Other transformations
return value
def _days_to_date(self, days_remaining: Any) -> Optional[str]:
"""Convert days remaining to an end date.
Args:
days_remaining: Number of days remaining
Returns:
ISO format date string or None if invalid
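        Examples:
            _days_to_date(30)    -> "YYYY-MM-DD" string 30 days from today
            _days_to_date("abc") -> None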
"""
if days_remaining is None:
return None
try:
days = int(days_remaining)
end_date = datetime.now() + timedelta(days=days)
return end_date.strftime("%Y-%m-%d")
except (ValueError, TypeError):
return None
def _convert_type(
self,
value: Any,
from_type: SemanticType,
to_type: SemanticType
) -> Any:
"""Convert a value from one semantic type to another.
Args:
value: Value to convert
from_type: Current semantic type
to_type: Target semantic type
Returns:
Converted value
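        Examples (per the rules below):
            INTEGER -> DATE treats the integer as days remaining
            TEXT "123.4" -> INTEGER yields 123
            Unsupported combinations return the value unchanged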
"""
if value is None:
return None
# Date conversions
if to_type == SemanticType.DATE:
if from_type == SemanticType.INTEGER:
# Assume integer is days remaining
return self._days_to_date(value)
elif from_type == SemanticType.TEXT:
# Parse text date
try:
dt = datetime.fromisoformat(str(value))
return dt.strftime("%Y-%m-%d")
except (ValueError, TypeError):
return str(value)
# Numeric conversions
if to_type == SemanticType.FLOAT:
if from_type in (SemanticType.INTEGER, SemanticType.TEXT):
try:
return float(value)
except (ValueError, TypeError):
return value
if to_type == SemanticType.INTEGER:
if from_type in (SemanticType.FLOAT, SemanticType.TEXT):
try:
return int(float(value))
except (ValueError, TypeError):
return value
# Text conversion (always works)
if to_type == SemanticType.TEXT:
return str(value)
# No conversion available
return value
def _normalize_field_name(
self,
customer_field_name: str,
customer_id: str
) -> Optional[str]:
"""Map a customer-specific field name to its semantic concept.
Args:
customer_field_name: Field name in customer schema
customer_id: Customer ID
Returns:
Semantic concept ID or None if not mapped
"""
# Check all concepts for this customer
for concept_id in self.knowledge_graph.concepts.keys():
mapping = self.knowledge_graph.get_mapping(concept_id, customer_id)
if mapping and mapping.column_name == customer_field_name:
return concept_id
return None
def _normalize_industry_name(self, industry: Optional[str]) -> Optional[str]:
"""Normalize industry names to common format.
Args:
industry: Industry name from customer data
Returns:
Normalized industry name
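        Examples:
            "tech"          -> "Technology"
            "public sector" -> "Government"
            "aerospace"     -> "Aerospace" (unmapped names are title-cased)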
"""
if not industry:
return None
# Convert to lowercase for comparison
industry_lower = industry.lower().strip()
# Map common variations
industry_mapping = {
"tech": "Technology",
"technology": "Technology",
"it": "Technology",
"information technology": "Technology",
"healthcare": "Healthcare",
"health": "Healthcare",
"medical": "Healthcare",
"finance": "Financial Services",
"financial": "Financial Services",
"financial services": "Financial Services",
"banking": "Financial Services",
"retail": "Retail",
"manufacturing": "Manufacturing",
"mfg": "Manufacturing",
"education": "Education",
"edu": "Education",
"government": "Government",
"gov": "Government",
"public sector": "Government",
}
return industry_mapping.get(industry_lower, industry.title())
def _harmonize_row(
self,
row: Dict[str, Any],
customer_id: str,
field_mappings: Dict[str, str]
) -> Dict[str, Any]:
"""Harmonize a single row of data.
Args:
row: Raw row data from customer database
customer_id: Customer ID
field_mappings: Map of customer field names to concept IDs
Returns:
Harmonized row with normalized field names and values
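        Example (illustrative; column and concept names are placeholders):
            row {"acct_industry": "tech", "deal_amt": 50000} with mappings
            {"acct_industry": "industry_sector", "deal_amt": "contract_value"}
            becomes {"industry_sector": "Technology", "contract_value": 50000}.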
"""
harmonized = {}
# First, map all fields that have concept mappings
for customer_field, concept_id in field_mappings.items():
if customer_field in row:
value = row[customer_field]
# Special handling for industry
if concept_id == "industry_sector":
harmonized[concept_id] = self._normalize_industry_name(value)
else:
# Normalize the value
normalized = self._normalize_value(value, customer_id, concept_id)
harmonized[concept_id] = normalized.normalized_value
else:
# Field not present in row
harmonized[concept_id] = None
# Second, include any unmapped columns (like aggregation results)
# These are columns like count_contract_identifier, sum_contract_value, etc.
for field_name, value in row.items():
if field_name not in field_mappings:
# This is an unmapped column (likely an aggregation)
# Include it as-is with its original name
harmonized[field_name] = value
return harmonized
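

# ---------------------------------------------------------------------------
# Illustrative usage sketch, not executed on import. The no-argument
# SchemaKnowledgeGraph() constructor and the SemanticQueryPlan fields shown
# here are assumptions for illustration only; see knowledge_graph.py and
# models.py for the actual signatures in this repository.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    kg = SchemaKnowledgeGraph()  # assumption: defaults load the concept mappings
    harmonizer = ResultHarmonizer(kg)
    # Assumption: a plan would normally come from the query planner; the
    # projection IDs are placeholders for concepts defined in the graph.
    plan = SemanticQueryPlan(projections=["contract_value", "industry_sector"])
    result = harmonizer.execute_across_customers(plan, parallel=True)
    print(
        f"{result.total_count} rows from {len(result.customers_succeeded)} "
        f"customers in {result.execution_time_ms:.1f} ms"
    )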