| """ | |
| Tool for analyzing datasets for causal inference. | |
| This module provides a LangChain tool for analyzing datasets to detect | |
| characteristics relevant for causal inference, such as temporal structure, | |
| potential instrumental variables, and variable relationships. | |
| """ | |
import logging
from typing import Optional

from langchain.tools import tool

from auto_causal.components.dataset_analyzer import analyze_dataset
from auto_causal.components.state_manager import create_workflow_state_update
from auto_causal.config import get_llm_client

# Import the required Pydantic models
from auto_causal.models import (
    DatasetAnalysis,
    DatasetAnalyzerOutput,
    DatasetInfo,
    TemporalStructure,
)

logger = logging.getLogger(__name__)


@tool
def dataset_analyzer_tool(dataset_path: str,
                          dataset_description: Optional[str] = None,
                          original_query: Optional[str] = None) -> DatasetAnalyzerOutput:
| """ | |
| Analyze dataset to identify important characteristics for causal inference. | |
| This tool loads the dataset, calculates summary statistics, checks for temporal | |
| structure, identifies potential treatments/outcomes/instruments, and assesses | |
| variable relationships relevant for selecting a causal method. | |
| Args: | |
| dataset_path: Path to the dataset file. | |
| dataset_description: Optional description string from input. | |
| llm: Optional LLM client for enhanced analysis. | |
| Returns: | |
| A Pydantic model containing the structured dataset analysis results and workflow state. | |
| """ | |
    logger.info(f"Running dataset_analyzer_tool on path: {dataset_path}")

    # Fetch the configured LLM client so the component can use it for enhanced analysis
    llm = get_llm_client()

    try:
        # Call the component function
        analysis_dict = analyze_dataset(
            dataset_path,
            llm_client=llm,
            dataset_description=dataset_description,
            original_query=original_query,
        )

        # Check for errors returned explicitly by the component
        if isinstance(analysis_dict, dict) and "error" in analysis_dict:
            logger.error(f"Dataset analysis component failed: {analysis_dict['error']}")
            raise ValueError(analysis_dict['error'])

        # Validate and structure the analysis using Pydantic. This assumes
        # analyze_dataset returns a dict compatible with DatasetAnalysis; missing
        # keys or type mismatches raise here and are handled by the except block.
        analysis_results_model = DatasetAnalysis(**analysis_dict)

    except Exception as e:
        logger.error(f"Error during dataset analysis or Pydantic model creation: {e}", exc_info=True)
        error_state = create_workflow_state_update(
            current_step="data_analysis",
            step_completed_flag=False,
            next_tool="dataset_analyzer_tool",  # Retry the analysis; could route to an error handler instead
            next_step_reason=f"Dataset analysis failed: {e}"
        )
        # Build a minimal, well-typed analysis object so downstream tools still
        # receive a valid DatasetAnalyzerOutput even when analysis fails
        minimal_info = DatasetInfo(num_rows=0, num_columns=0, file_path=dataset_path, file_name="unknown")
        empty_temporal = TemporalStructure(has_temporal_structure=False, temporal_columns=[], is_panel_data=False)
        error_analysis = DatasetAnalysis(
            dataset_info=minimal_info,
            columns=[],
            potential_treatments=[],
            potential_outcomes=[],
            temporal_structure_detected=False,
            panel_data_detected=False,
            potential_instruments_detected=False,
            discontinuities_detected=False,
            temporal_structure=empty_temporal,
            sample_size=0,
            num_covariates_estimate=0
        )
        return DatasetAnalyzerOutput(
            analysis_results=error_analysis,
            dataset_description=dataset_description,
            dataset_path=dataset_path,
            workflow_state=error_state.get('workflow_state', {})
        )

    # Create workflow state update for success
    workflow_update = create_workflow_state_update(
        current_step="data_analysis",
        step_completed_flag="dataset_analyzed",
        next_tool="query_interpreter_tool",
        next_step_reason="Now we need to map query concepts to actual dataset variables"
    )

    # Construct the final Pydantic output object
    output = DatasetAnalyzerOutput(
        analysis_results=analysis_results_model,
        dataset_description=dataset_description,
        dataset_path=dataset_path,
        workflow_state=workflow_update.get('workflow_state', {})
    )
    logger.info("dataset_analyzer_tool finished successfully.")
    return output
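

# Minimal usage sketch: because dataset_analyzer_tool is decorated with @tool,
# it is a LangChain StructuredTool and is invoked with a dict of arguments.
# The dataset path and description below are hypothetical placeholders.
if __name__ == "__main__":
    result = dataset_analyzer_tool.invoke({
        "dataset_path": "data/example.csv",  # hypothetical path
        "dataset_description": "Example observational dataset",  # illustrative
    })
    # result is a DatasetAnalyzerOutput; inspect the analysis and the
    # workflow_state that routes to the next tool (query_interpreter_tool)
    print(result.analysis_results.dataset_info)
    print(result.workflow_state)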