| """ | |
| Method Executor Tool for the causal inference agent. | |
| Executes the selected causal inference method using its implementation function. | |
| """ | |
import json
import logging
from typing import Dict, Any, Optional

import pandas as pd

# Import the method mapping and preprocessing utilities
from auto_causal.methods import METHOD_MAPPING
from auto_causal.methods.utils import preprocess_data  # Assuming preprocess exists
from auto_causal.components.state_manager import create_workflow_state_update
from auto_causal.config import get_llm_client  # LLM client factory
# Import shared models from the central location
from auto_causal.models import (
    Variables,
    TemporalStructure,  # Needed indirectly by DatasetAnalysis
    DatasetInfo,  # Needed indirectly by DatasetAnalysis
    DatasetAnalysis,
    MethodExecutorInput,
)

# Optional path to a JSONL log file; the calling pipeline may set this before
# invoking the tool to capture per-run analysis summaries.
CURRENT_OUTPUT_LOG_FILE = None

logger = logging.getLogger(__name__)
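
# A caller might enable run logging like so (illustrative sketch; the exact
# import path for this module is an assumption):
#
#     from auto_causal.tools import method_executor_tool as met
#     met.CURRENT_OUTPUT_LOG_FILE = "runs/analysis_log.jsonl"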


def method_executor_tool(inputs: MethodExecutorInput, original_query: Optional[str] = None) -> Dict[str, Any]:
    """Execute the selected causal inference method function using structured input.

    Args:
        inputs: Pydantic model containing method, variables, dataset_path,
            dataset_analysis, and dataset_description.
        original_query: Optional natural-language query; falls back to
            inputs.original_query when not provided.

    Returns:
        Dict with numerical results, context for the next step, and workflow state.
    """
    # Access data from the input model
    method = inputs.method
    variables_dict = inputs.variables.model_dump()
    dataset_path = inputs.dataset_path
    dataset_analysis_dict = inputs.dataset_analysis.model_dump()
    dataset_description_str = inputs.dataset_description
    validation_info = inputs.validation_info  # Can be passed along if needed

    logger.info(f"Executing method: {method}")

    try:
        # --- Get LLM instance ---
        llm_instance = None
        try:
            llm_instance = get_llm_client()
        except Exception as llm_e:
            logger.warning(
                f"Could not get LLM client in method_executor_tool: {llm_e}. "
                "LLM-dependent features in the method will be disabled."
            )

        # 1. Load data
        if not dataset_path:
            raise ValueError("Dataset path is missing.")
        df = pd.read_csv(dataset_path)

        # 2. Extract the key variables required by the estimate_func signature
        treatment = variables_dict.get("treatment_variable")
        outcome = variables_dict.get("outcome_variable")
        covariates = variables_dict.get("covariates", [])
        query_str = original_query if original_query is not None else inputs.original_query

        if not all([treatment, outcome]):
            raise ValueError("Treatment or outcome variable not found in 'variables' dict.")

        # 3. Preprocess data
        required_cols_for_method = [treatment, outcome] + covariates
        # Add method-specific required variables from variables_dict
        if method == "instrumental_variable" and variables_dict.get("instrument_variable"):
            required_cols_for_method.append(variables_dict["instrument_variable"])
        elif method == "regression_discontinuity" and variables_dict.get("running_variable"):
            required_cols_for_method.append(variables_dict["running_variable"])

        missing_df_cols = [col for col in required_cols_for_method if col not in df.columns]
        if missing_df_cols:
            raise ValueError(
                f"Dataset at {dataset_path} is missing required columns "
                f"for method '{method}': {missing_df_cols}"
            )

        df_processed, updated_treatment, updated_outcome, updated_covariates, column_mappings = \
            preprocess_data(df, treatment, outcome, covariates, verbose=False)
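        # Note: column_mappings is assumed to map original column names to their
        # preprocessed counterparts (e.g. renamed or encoded columns); it is
        # threaded through to downstream tools unchanged.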

        # 4. Look up the method's execution function
        if method not in METHOD_MAPPING:
            raise ValueError(f"Method '{method}' not found in METHOD_MAPPING.")
        estimate_func = METHOD_MAPPING[method]

        # 5. Execute the method. Pass only the specific arguments estimate_func
        # expects (e.g. instrument_variable, running_variable, cutoff_value)
        # rather than the entire variables_dict.
        kwargs_for_method = {}
        for key in ["instrument_variable", "time_variable", "group_variable",
                    "running_variable", "cutoff_value"]:
            if key in variables_dict and variables_dict[key] is not None:
                kwargs_for_method[key] = variables_dict[key]

        # Add the newer fields from the Variables model (inputs.variables)
        if inputs.variables.treatment_reference_level is not None:
            kwargs_for_method['treatment_reference_level'] = inputs.variables.treatment_reference_level
        if inputs.variables.interaction_term_suggested is not None:  # boolean, so check for None to allow False
            kwargs_for_method['interaction_term_suggested'] = inputs.variables.interaction_term_suggested
        if inputs.variables.interaction_variable_candidate is not None:
            kwargs_for_method['interaction_variable_candidate'] = inputs.variables.interaction_variable_candidate

        # Add the query and column mappings for llm_assist functions within the method
        kwargs_for_method['query'] = query_str
        kwargs_for_method['column_mappings'] = column_mappings
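        # For an instrumental-variable run, kwargs_for_method might end up as
        # (illustrative values only):
        #   {"instrument_variable": "lottery_number", "query": "...",
        #    "column_mappings": {...}}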

        results_dict = estimate_func(
            df=df_processed,
            treatment=updated_treatment,
            outcome=updated_outcome,
            covariates=updated_covariates,
            dataset_description=dataset_description_str,
            query_str=query_str,
            llm=llm_instance,
            **kwargs_for_method  # Specific args needed by the method
        )
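        # results_dict is expected to expose keys such as "effect_estimate",
        # "confidence_interval", "standard_error", and "p_value", which are
        # unpacked into the output dictionary below.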

        # 6. Prepare output
        logger.info(f"Method execution successful. Effect estimate: {results_dict.get('effect_estimate')}")

        # Add workflow state
        workflow_update = create_workflow_state_update(
            current_step="method_execution",
            step_completed_flag="method_executed",
            next_tool="explainer_tool",
            next_step_reason="Now we need to explain the results and their implications"
        )

        # --- Prepare output dictionary ---
        # Structure required by explainer_tool: top-level context plus a nested "results" dict
        final_output = {
            # Nested dictionary of numerical results and diagnostics
            "results": {
                # Core estimation results (extracted from results_dict)
                "effect_estimate": results_dict.get("effect_estimate"),
                "confidence_interval": results_dict.get("confidence_interval"),
                "standard_error": results_dict.get("standard_error"),
                "p_value": results_dict.get("p_value"),
                "method_used": results_dict.get("method_used"),
                "llm_assumption_check": results_dict.get("llm_assumption_check"),
                "raw_results": results_dict.get("raw_results"),
                # Diagnostics and refutation results
                "diagnostics": results_dict.get("diagnostics"),
                "refutation_results": results_dict.get("refutation_results")
            },
            # Top-level context to be passed along
            "variables": variables_dict,
            "dataset_analysis": dataset_analysis_dict,
            "dataset_description": dataset_description_str,
            "validation_info": validation_info,
            "original_query": inputs.original_query,
            "column_mappings": column_mappings
            # Workflow state is merged in below
        }

        # Merge the workflow state into the final output
        final_output.update(workflow_update.get('workflow_state', {}))

        # --- Logging logic (moved from output_formatter.py) ---
        # Build a summary of the headline numbers for the run log
        summary_dict = {
            "query": inputs.original_query,
            "method_used": results_dict.get("method_used"),
            "causal_effect": results_dict.get("effect_estimate"),
            "standard_error": results_dict.get("standard_error"),
            "confidence_interval": results_dict.get("confidence_interval")
        }
        logger.debug(f"summary_dict: {summary_dict}")
        logger.debug(f"CURRENT_OUTPUT_LOG_FILE: {CURRENT_OUTPUT_LOG_FILE}")
        if CURRENT_OUTPUT_LOG_FILE and summary_dict:
            try:
                log_entry = {"type": "analysis_result", "data": summary_dict}
                with open(CURRENT_OUTPUT_LOG_FILE, mode='a', encoding='utf-8') as log_file:
                    log_file.write('\n' + json.dumps(log_entry) + '\n')
            except Exception as e:
                logger.error(
                    f"Failed to write analysis results to log file "
                    f"'{CURRENT_OUTPUT_LOG_FILE}': {e}"
                )

        return final_output

    except Exception as e:
        error_message = f"Error executing method {method}: {str(e)}"
        logger.error(error_message, exc_info=True)

        # Return an error state, including as much context as is available
        workflow_update = create_workflow_state_update(
            current_step="method_execution",
            step_completed_flag=False,
            next_tool="explainer_tool",  # Or an error handler?
            next_step_reason=f"Failed during method execution: {error_message}"
        )
        # Ensure the error output still carries the context keys where possible
        error_result = {
            "error": error_message,
            "variables": variables_dict if 'variables_dict' in locals() else {},
            "dataset_analysis": dataset_analysis_dict if 'dataset_analysis_dict' in locals() else {},
            "dataset_description": dataset_description_str if 'dataset_description_str' in locals() else None,
            "original_query": inputs.original_query,
            "column_mappings": column_mappings if 'column_mappings' in locals() else {}
        }
        error_result.update(workflow_update.get('workflow_state', {}))
        return error_result
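

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only: the exact required fields of
# Variables, DatasetAnalysis, and MethodExecutorInput are defined in
# auto_causal.models, and all values below are hypothetical):
#
#     from auto_causal.models import MethodExecutorInput, Variables
#
#     inputs = MethodExecutorInput(
#         method="instrumental_variable",
#         variables=Variables(
#             treatment_variable="education",
#             outcome_variable="wage",
#             covariates=["age", "experience"],
#             instrument_variable="quarter_of_birth",
#         ),
#         dataset_path="data/wages.csv",
#         dataset_analysis=dataset_analysis,  # a DatasetAnalysis instance
#         dataset_description="Survey of wages and schooling",
#     )
#     result = method_executor_tool(inputs, original_query="Does education raise wages?")
#     print(result["results"]["effect_estimate"])
# ---------------------------------------------------------------------------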