Spaces:
Running
Running
| """ | |
| Explanation generator tool for causal inference methods. | |
| This tool generates explanations for the selected causal inference method, | |
| including what the method does, its assumptions, and how it will be applied. | |
| """ | |
| from typing import Dict, Any, Optional, List, Union | |
| from langchain.tools import tool | |
| import logging | |
| from auto_causal.components.explanation_generator import generate_explanation | |
| from auto_causal.components.state_manager import create_workflow_state_update | |
| from auto_causal.config import get_llm_client | |
| # Import shared models from central location | |
| from auto_causal.models import ( | |
| Variables, | |
| TemporalStructure, # Needed indirectly by DatasetAnalysis | |
| DatasetInfo, # Needed indirectly by DatasetAnalysis | |
| DatasetAnalysis, | |
| MethodInfo, | |
| ExplainerInput # Keep for type hinting arguments | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # --- Removed local Pydantic definitions --- | |
| # class Variables(BaseModel): ... | |
| # class TemporalStructure(BaseModel): ... | |
| # class DatasetInfo(BaseModel): ... | |
| # class DatasetAnalysis(BaseModel): ... | |
| # class MethodInfo(BaseModel): ... | |
| # class ExplainerInput(BaseModel): ... | |
| # --- Tool Definition --- | |
| # Change signature to accept individual arguments | |
def explanation_generator_tool(
    method_info: MethodInfo,
    variables: Variables,
    results: Dict[str, Any],
    dataset_analysis: DatasetAnalysis,
    validation_info: Optional[Dict[str, Any]] = None,
    dataset_description: Optional[str] = None,
    original_query: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Generate an explanation for a causal analysis and package it for the next tool.

    The Pydantic inputs are dumped to plain dicts, an LLM client is obtained on a
    best-effort basis, and everything is handed to ``generate_explanation``.

    Args:
        method_info: Pydantic model with method details; its ``selected_method``
            field (if present) is echoed back under ``"method"``.
        variables: Pydantic model with identified variables.
        results: Dictionary containing numerical results from execution.
        dataset_analysis: Pydantic model with dataset analysis results.
        validation_info: Optional dictionary with validation results; forwarded
            to the component as ``validation_result``.
        dataset_description: Optional string description of the dataset.
        original_query: Optional original user query string. When truthy it is
            also added to the variables dict under ``"original_query"``.

    Returns:
        Dictionary with the keys ``query``, ``method``, ``results``,
        ``explanation``, ``dataset_analysis`` and ``dataset_description``, merged
        with whatever ``create_workflow_state_update`` returns. If the component
        raises (or returns a non-dict), the same structure is returned with an
        additional ``"error"`` key and ``"explanation"`` set to
        ``{"error": <message>}``; no exception propagates to the caller.
    """
    logger.info("Running explainer_tool with direct arguments...")

    # The component works on plain dicts, so dump the Pydantic models up front.
    method_info_dict = method_info.model_dump()
    # Diagnostic output goes through the logger rather than stdout.
    logger.debug("Method info passed to explainer: %s", method_info_dict)
    variables_dict = variables.model_dump()
    dataset_analysis_dict = dataset_analysis.model_dump()

    # The original query travels inside the variables dict.
    # NOTE(review): presumably the component reads it from there - confirm.
    if original_query:
        variables_dict['original_query'] = original_query

    # Best effort: a missing LLM client is not fatal here; the component is
    # called with llm=None in that case.
    llm_instance = None
    try:
        llm_instance = get_llm_client()
    except Exception as e:
        logger.warning(f"Could not get LLM client for explainer: {e}")

    try:
        explanation_dict = generate_explanation(
            method_info=method_info_dict,
            validation_result=validation_info,
            variables=variables_dict,
            results=results,
            dataset_analysis=dataset_analysis_dict,
            dataset_description=dataset_description,
            llm=llm_instance,
        )
        # Checked inside the try so that a malformed return value takes the
        # same error path as an exception raised by the component.
        if not isinstance(explanation_dict, dict):
            raise TypeError(f"generate_explanation component did not return a dict. Got: {type(explanation_dict)}")
    except Exception as e:
        logger.error(f"Error during generate_explanation execution: {e}", exc_info=True)
        workflow_update = create_workflow_state_update(
            current_step="result_explanation",
            step_completed_flag=False,
            error=f"Component failed: {e}",
            next_tool="explanation_generator_tool",  # Indicate failed tool
            next_step_reason=f"Explanation generation component failed: {e}",
        )
        error_result = {
            "error": f"Explanation generation component failed: {e}",
            # Context for a potential retry or for the next step.
            "query": original_query or "N/A",
            "method": method_info_dict.get('selected_method', "N/A"),
            "results": results,  # Keep the numbers even though explaining failed
            "explanation": {"error": str(e)},
            "dataset_analysis": dataset_analysis_dict,
            "dataset_description": dataset_description,
        }
        # Merge the workflow update exactly as the success path does, so both
        # return shapes match and the update's inner keys cannot overwrite the
        # payload keys above (e.g. "error").
        error_result.update(workflow_update)
        return error_result

    workflow_update = create_workflow_state_update(
        current_step="result_explanation",
        step_completed_flag="results_explained",
        next_tool="output_formatter_tool",  # Step 8: Format output
        next_step_reason="Finally, we need to format the output for presentation",
    )

    # Everything the formatter tool needs; the analysis and description are
    # passed along in full for now.
    result_for_formatter = {
        "query": original_query or "N/A",
        "method": method_info_dict.get('selected_method', 'N/A'),
        "results": results,
        "explanation": explanation_dict,
        "dataset_analysis": dataset_analysis_dict,
        "dataset_description": dataset_description,
    }
    result_for_formatter.update(workflow_update)

    logger.info("explanation_generator_tool finished successfully.")
    return result_for_formatter