Spaces:
Paused
Paused
| """Database executor for running queries against customer databases.""" | |
| import sqlite3 | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| from schema_translator.config import get_config | |
| from schema_translator.models import QueryResult | |
class DatabaseExecutor:
    """Executes SQL queries against customer databases.

    One SQLite connection per customer is opened on first use and cached in
    ``self._connections``. Cached connections are released by
    ``close_connection``, ``close_all_connections``, leaving a ``with`` block,
    or when the executor is garbage collected.
    """

    def __init__(self):
        """Initialize the database executor."""
        # Create the cache first so that cleanup in __del__ stays safe even
        # if loading the configuration raises.
        self._connections: Dict[str, sqlite3.Connection] = {}
        self.config = get_config()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so names containing spaces,
        quotes or reserved words can be interpolated into a statement
        without changing its structure.
        """
        return '"' + name.replace('"', '""') + '"'

    @staticmethod
    def _elapsed_ms(start: float) -> float:
        """Return milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return (time.perf_counter() - start) * 1000

    def execute_query(
        self,
        customer_id: str,
        sql: str
    ) -> "QueryResult":
        """Execute a SQL query for a specific customer.

        Failures are not raised: any exception is caught and reported through
        the ``error`` field of the returned result, with empty data.

        Args:
            customer_id: Customer identifier
            sql: SQL query to execute

        Returns:
            QueryResult with data and execution metadata
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        start_time = time.perf_counter()

        try:
            conn = self._get_connection(customer_id)
            cursor = conn.cursor()
            try:
                cursor.execute(sql)
                rows = cursor.fetchall()
                # description is None for statements that return no rows.
                column_names = (
                    [desc[0] for desc in cursor.description]
                    if cursor.description
                    else []
                )
            finally:
                cursor.close()

            # One dict per row, keyed by column name.
            data = [dict(zip(column_names, row)) for row in rows]

            return QueryResult(
                customer_id=customer_id,
                data=data,
                sql_executed=sql,
                execution_time_ms=self._elapsed_ms(start_time),
                row_count=len(data)
            )

        except Exception as e:
            # Deliberately broad: callers receive the failure as data, and
            # the execution time is still reported for failed queries.
            return QueryResult(
                customer_id=customer_id,
                data=[],
                sql_executed=sql,
                execution_time_ms=self._elapsed_ms(start_time),
                row_count=0,
                error=str(e)
            )

    def execute_for_all_customers(
        self,
        sql_by_customer: Dict[str, str]
    ) -> "List[QueryResult]":
        """Execute queries for multiple customers.

        Queries run sequentially, in the iteration order of the mapping.

        Args:
            sql_by_customer: Dictionary mapping customer_id to SQL query

        Returns:
            List of QueryResult objects
        """
        return [
            self.execute_query(customer_id, sql)
            for customer_id, sql in sql_by_customer.items()
        ]

    def execute_raw_query(
        self,
        customer_id: str,
        sql: str
    ) -> List[Dict[str, Any]]:
        """Execute a raw SQL query and return results directly.

        Simpler interface for direct queries without QueryResult wrapper.

        Args:
            customer_id: Customer identifier
            sql: SQL query to execute

        Returns:
            List of result dictionaries

        Raises:
            RuntimeError: If the query reported an error
        """
        result = self.execute_query(customer_id, sql)

        if result.error:
            raise RuntimeError(f"Query failed: {result.error}")

        return result.data

    def test_connection(self, customer_id: str) -> bool:
        """Test if database connection is working.

        Args:
            customer_id: Customer identifier

        Returns:
            True if connection works, False otherwise
        """
        try:
            conn = self._get_connection(customer_id)
            cursor = conn.cursor()
            try:
                cursor.execute("SELECT 1")
                cursor.fetchone()
            finally:
                cursor.close()
            return True
        except Exception:
            # This is a yes/no probe: every failure means "not working".
            return False

    def get_table_info(self, customer_id: str) -> Dict[str, List[Dict[str, Any]]]:
        """Get information about tables in a customer database.

        Args:
            customer_id: Customer identifier

        Returns:
            Dictionary mapping table names to column info. Each column is a
            dict with the keys ``name``, ``type``, ``notnull``, ``default``
            and ``primary_key``.
        """
        conn = self._get_connection(customer_id)
        cursor = conn.cursor()
        try:
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = [row[0] for row in cursor.fetchall()]

            table_info = {}
            for table in tables:
                # Quote the name: table names may contain spaces or be
                # reserved words, which would otherwise be a syntax error.
                cursor.execute(
                    f"PRAGMA table_info({self._quote_identifier(table)})"
                )
                # PRAGMA table_info rows: cid, name, type, notnull,
                # dflt_value, pk.
                table_info[table] = [
                    {
                        "name": row[1],
                        "type": row[2],
                        "notnull": bool(row[3]),
                        "default": row[4],
                        "primary_key": bool(row[5])
                    }
                    for row in cursor.fetchall()
                ]
        finally:
            cursor.close()

        return table_info

    def count_rows(self, customer_id: str, table_name: str) -> int:
        """Count rows in a table.

        The table name is quoted as a single identifier, so any name returned
        by ``get_table_info`` can be passed in unchanged.

        Args:
            customer_id: Customer identifier
            table_name: Table name

        Returns:
            Number of rows
        """
        conn = self._get_connection(customer_id)
        cursor = conn.cursor()
        try:
            # Identifiers cannot be bound as parameters, so quote instead of
            # interpolating the raw name.
            cursor.execute(
                f"SELECT COUNT(*) FROM {self._quote_identifier(table_name)}"
            )
            return cursor.fetchone()[0]
        finally:
            cursor.close()

    def _get_connection(self, customer_id: str) -> sqlite3.Connection:
        """Get or create a database connection for a customer.

        Args:
            customer_id: Customer identifier

        Returns:
            SQLite connection

        Raises:
            FileNotFoundError: If database file doesn't exist
        """
        # Reuse existing connection if available
        cached = self._connections.get(customer_id)
        if cached is not None:
            return cached

        db_path = self.config.get_database_path(customer_id)

        # Check explicitly: sqlite3.connect() would otherwise create a new,
        # empty database file at a missing path.
        if not db_path.exists():
            raise FileNotFoundError(f"Database not found: {db_path}")

        conn = sqlite3.connect(str(db_path))
        try:
            # Enable foreign keys
            conn.execute("PRAGMA foreign_keys = ON")
        except Exception:
            # Do not leak a connection that never made it into the cache.
            conn.close()
            raise

        self._connections[customer_id] = conn
        return conn

    def close_all_connections(self) -> None:
        """Close all database connections."""
        # Remove each entry before closing it, so a failing close() can never
        # leave an unusable connection behind in the cache.
        while self._connections:
            _, conn = self._connections.popitem()
            conn.close()

    def close_connection(self, customer_id: str) -> None:
        """Close a specific customer's database connection.

        Does nothing if no connection is cached for the customer.

        Args:
            customer_id: Customer identifier
        """
        conn = self._connections.pop(customer_id, None)
        if conn is not None:
            conn.close()

    def __del__(self):
        """Cleanup connections on deletion."""
        # __del__ also runs for instances whose __init__ never completed, so
        # the cache attribute may be missing.
        if not getattr(self, "_connections", None):
            return
        try:
            self.close_all_connections()
        except sqlite3.Error:
            # Best effort only: a finalizer has no caller to report to.
            pass

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - close connections."""
        self.close_all_connections()