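"""Instrumental variable (IV) effect estimation.

Builds an IV causal graph, estimates the treatment effect with DoWhy
(primary path) or statsmodels IV2SLS (fallback), runs diagnostics and a
placebo refutation test, and formats the results, optionally adding an
LLM-generated interpretation.
"""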
import logging
from typing import Any, Dict, List, Optional, Union

import pandas as pd
import statsmodels.api as sm
from statsmodels.sandbox.regression.gmm import IV2SLS
from dowhy import CausalModel  # Primary estimation path
from langchain.chat_models.base import BaseChatModel

from .diagnostics import run_iv_diagnostics
from .llm_assist import (
    identify_instrument_variable,
    interpret_iv_results,
    validate_instrument_assumptions_qualitative,
)

logger = logging.getLogger(__name__)
def build_iv_graph_gml(treatment: str, outcome: str, instruments: List[str], covariates: List[str]) -> str:
    """
    Constructs a GML string representing the causal graph for IV estimation.

    Assumptions encoded in the graph:
    - Instruments cause Treatment (relevance)
    - Covariates cause Treatment and Outcome
    - Treatment causes Outcome
    - Instruments do NOT directly cause Outcome (exclusion restriction)
    - Instruments are NOT caused by Covariates (can be relaxed if needed)
    - An unobserved confounder (U) affects both Treatment and Outcome

    Args:
        treatment: Name of the treatment variable.
        outcome: Name of the outcome variable.
        instruments: List of instrument variable names.
        covariates: List of covariate names.

    Returns:
        A GML graph string.
    """
    nodes = []
    edges = []

    # Define nodes. A set guards against duplicate node entries in case a
    # variable appears as both instrument and covariate (which should not happen).
    all_vars_set = set([treatment, outcome] + instruments + covariates + ['U'])
    all_vars = list(all_vars_set)
    for var in all_vars:
        nodes.append(f'node [ id "{var}" label "{var}" ]')

    # Instruments -> Treatment
    for inst in instruments:
        edges.append(f'edge [ source "{inst}" target "{treatment}" ]')

    # Covariates -> Treatment (skip self-loops, which should not occur)
    for cov in covariates:
        if cov != treatment:
            edges.append(f'edge [ source "{cov}" target "{treatment}" ]')

    # Covariates -> Outcome
    for cov in covariates:
        if cov != outcome:
            edges.append(f'edge [ source "{cov}" target "{outcome}" ]')

    # Treatment -> Outcome
    edges.append(f'edge [ source "{treatment}" target "{outcome}" ]')

    # Unobserved confounder -> Treatment and Outcome
    edges.append(f'edge [ source "U" target "{treatment}" ]')
    edges.append(f'edge [ source "U" target "{outcome}" ]')

    # The core IV assumptions are encoded by *omission*: no edge from U to the
    # instruments (independence) and no direct edge from the instruments to
    # the outcome (exclusion).

    # Format nodes and edges with indentation before inserting into the f-string
    formatted_nodes = '\n    '.join(nodes)
    formatted_edges = '\n    '.join(edges)
    gml_string = f"""
graph [
    directed 1
    {formatted_nodes}
    {formatted_edges}
]
"""
    logger.debug("--- Generated GML Graph ---")
    logger.debug(gml_string)
    return gml_string
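
# For illustration, with hypothetical variables treatment="T", outcome="Y",
# instruments=["Z"], and covariates=["X"], build_iv_graph_gml returns a GML
# string along these lines (node order varies because a set is used):
#
#   graph [
#       directed 1
#       node [ id "T" label "T" ]
#       ...
#       edge [ source "Z" target "T" ]
#       edge [ source "X" target "T" ]
#       edge [ source "X" target "Y" ]
#       edge [ source "T" target "Y" ]
#       edge [ source "U" target "T" ]
#       edge [ source "U" target "Y" ]
#   ]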
def format_iv_results(estimate: Optional[float], raw_results: Dict, diagnostics: Dict, treatment: str, outcome: str, instrument: List[str], method_used: str, llm: Optional[BaseChatModel] = None) -> Dict[str, Any]:
    """
    Formats the results from IV estimation into a standardized dictionary.

    Args:
        estimate: The point estimate of the causal effect.
        raw_results: Dictionary containing raw outputs from DoWhy/statsmodels.
        diagnostics: Dictionary containing diagnostic results.
        treatment: Name of the treatment variable.
        outcome: Name of the outcome variable.
        instrument: List of instrument variable names.
        method_used: 'dowhy' or 'statsmodels'.
        llm: Optional LLM instance for interpretation.

    Returns:
        Standardized results dictionary.
    """
    formatted = {
        "effect_estimate": estimate,
        "treatment_variable": treatment,
        "outcome_variable": outcome,
        "instrument_variables": instrument,
        "method_used": method_used,
        "diagnostics": diagnostics,
        # Stringify values, skipping raw result objects to avoid serializing them
        "raw_results": {k: str(v) for k, v in raw_results.items() if "object" not in k},
        "confidence_interval": None,
        "standard_error": None,
        "p_value": None,
        "interpretation": "Placeholder",
    }

    # Extract details from statsmodels results if available
    sm_results = raw_results.get('statsmodels_results_object')
    if method_used == 'statsmodels' and sm_results:
        try:
            # .bse holds the standard errors in statsmodels results
            formatted["standard_error"] = float(sm_results.bse[treatment])
            formatted["p_value"] = float(sm_results.pvalues[treatment])
            conf_int = sm_results.conf_int().loc[treatment].tolist()
            formatted["confidence_interval"] = [float(ci) for ci in conf_int]
        except AttributeError as e:
            logger.warning(f"Could not extract all details from statsmodels results object (likely missing attribute): {e}")
        except Exception as e:
            logger.warning(f"Error extracting details from statsmodels results: {e}")

    # Extract details from DoWhy results if available.
    # Note: the CausalEstimate object's attributes vary across DoWhy versions,
    # so each one is probed defensively before use.
    dw_results = raw_results.get('dowhy_results_object')
    if method_used == 'dowhy' and dw_results:
        try:
            if hasattr(dw_results, 'stderr'):
                formatted["standard_error"] = float(dw_results.stderr)
            if hasattr(dw_results, 'p_value'):
                formatted["p_value"] = float(dw_results.p_value)
            if hasattr(dw_results, 'get_confidence_intervals'):
                ci = dw_results.get_confidence_intervals()  # Return format varies by DoWhy version
                # Only store the interval if it arrives as a simple (lower, upper) pair
                if isinstance(ci, (list, tuple)) and len(ci) == 2:
                    formatted["confidence_interval"] = [float(c) for c in ci]
                else:
                    logger.warning(f"Could not parse confidence intervals from DoWhy object: {ci}")
        except Exception as e:
            logger.warning(f"Could not extract all details from DoWhy results: {e}. Structure might be different.", exc_info=True)

    # Generate the LLM interpretation, passing through the llm instance
    if estimate is not None:
        formatted["interpretation"] = interpret_iv_results(formatted, diagnostics, llm=llm)
    else:
        formatted["interpretation"] = "Estimation failed, cannot interpret results."
    return formatted
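
# Sketch of the standardized output dictionary (field values below are
# illustrative, not taken from a real run):
#
#   {
#       "effect_estimate": 1.98,
#       "treatment_variable": "T",
#       "outcome_variable": "Y",
#       "instrument_variables": ["Z"],
#       "method_used": "statsmodels",
#       "diagnostics": {...},
#       "raw_results": {...},
#       "confidence_interval": [1.75, 2.21],
#       "standard_error": 0.12,
#       "p_value": 0.0001,
#       "interpretation": "...",
#   }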
def estimate_effect(
    df: pd.DataFrame,
    treatment: str,
    outcome: str,
    covariates: List[str],
    query: Optional[str] = None,
    dataset_description: Optional[str] = None,
    llm: Optional[BaseChatModel] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Estimates the causal effect of treatment on outcome via instrumental
    variables, trying DoWhy first and falling back to statsmodels IV2SLS.
    The instrument is read from kwargs['instrument_variable'].
    """
    instrument = kwargs.get('instrument_variable')
    if not instrument:
        return {"error": "Instrument variable ('instrument_variable') not found in kwargs.", "method_used": "none", "diagnostics": {}}
    instrument_list = [instrument] if isinstance(instrument, str) else instrument
    valid_instruments = [inst for inst in instrument_list if isinstance(inst, str)]
    # Drop any covariate that doubles as an instrument to keep the roles disjoint
    clean_covariates = [cov for cov in covariates if cov not in valid_instruments]

    logger.info("--- Starting Instrumental Variable Estimation ---")
    logger.info(f"Treatment: {treatment}, Outcome: {outcome}, Instrument(s): {valid_instruments}, Original Covariates: {covariates}, Cleaned Covariates: {clean_covariates}")

    results = {}
    method_used = "none"
    sm_results_obj = None
    dw_results_obj = None
    identified_estimand = None
    model = None
    refutation_results = {}

    # --- Input Validation ---
    required_cols = [treatment, outcome] + valid_instruments + clean_covariates
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        return {"error": f"Missing required columns in DataFrame: {missing_cols}", "method_used": method_used, "diagnostics": {}}
    if not valid_instruments:
        return {"error": "Instrument variable(s) must be provided and valid.", "method_used": method_used, "diagnostics": {}}
    # --- LLM Pre-Checks ---
    if query and llm:
        qualitative_check = validate_instrument_assumptions_qualitative(treatment, outcome, valid_instruments, clean_covariates, query, llm=llm)
        results['llm_assumption_check'] = qualitative_check
        logger.info(f"LLM Qualitative Assumption Check: {qualitative_check}")

    # --- Build Graph and Instantiate CausalModel (before any estimation attempt) ---
    # Building the model up front allows identify_effect and refute_estimate
    # to run even if the DoWhy estimation step itself fails.
    try:
        graph = build_iv_graph_gml(treatment, outcome, valid_instruments, clean_covariates)
        if not graph:
            raise ValueError("Failed to build GML graph for DoWhy.")
        model = CausalModel(data=df, treatment=treatment, outcome=outcome, graph=graph)
        # Identify the effect (also needed for refutation later)
        identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)
        logger.debug("DoWhy Identified Estimand:")
        logger.debug(identified_estimand)
        if not identified_estimand:
            raise ValueError("DoWhy could not identify a valid estimand.")
    except Exception as model_init_e:
        logger.error(f"Failed to initialize CausalModel or identify effect: {model_init_e}", exc_info=True)
        results['error'] = f"Failed to initialize CausalModel: {model_init_e}"
        # Without a model/estimand, DoWhy estimation and refutation are
        # impossible, but we still fall through to the statsmodels path below.
    # --- Primary Path: DoWhy Estimation ---
    if model and identified_estimand and not kwargs.get('force_statsmodels', False):
        logger.info("Attempting estimation with DoWhy...")
        try:
            dw_results_obj = model.estimate_effect(
                identified_estimand,
                method_name="iv.instrumental_variable",
                method_params={'iv_instrument_name': valid_instruments}
            )
            logger.debug("DoWhy Estimation Result:")
            logger.debug(dw_results_obj)
            results['dowhy_estimate'] = dw_results_obj.value
            results['dowhy_results_object'] = dw_results_obj
            method_used = 'dowhy'
            logger.info("DoWhy estimation successful.")
        except Exception as e:
            logger.error(f"DoWhy IV estimation failed: {e}", exc_info=True)
            results['dowhy_error'] = str(e)
            if not kwargs.get('allow_fallback', True):
                logger.warning("Fallback to statsmodels disabled. Estimation failed.")
                method_used = "dowhy_failed"
                # Diagnostics and output formatting still run below
            else:
                logger.info("Proceeding to statsmodels fallback.")
    elif not model or not identified_estimand:
        logger.warning("Skipping DoWhy estimation due to CausalModel initialization/identification failure.")
        if not kwargs.get('allow_fallback', True):
            logger.error("Cannot estimate effect: CausalModel failed and fallback disabled.")
            method_used = "dowhy_failed"
        else:
            logger.info("Proceeding to statsmodels fallback.")
    # --- Fallback Path: statsmodels IV2SLS ---
    if method_used not in ['dowhy', 'dowhy_failed']:
        logger.info("Attempting estimation with statsmodels IV2SLS...")
        try:
            df_copy = df.copy().dropna(subset=required_cols)
            if df_copy.empty:
                raise ValueError("DataFrame becomes empty after dropping NAs in required columns.")
            df_copy['intercept'] = 1
            exog_regressors = ['intercept'] + clean_covariates
            endog_var = treatment
            # Sandbox IV2SLS convention: the instrument matrix holds all
            # included exogenous regressors plus the excluded (external) instruments
            all_instruments_for_sm = list(dict.fromkeys(exog_regressors + valid_instruments))
            endog_data = df_copy[outcome]
            exog_data_sm_cols = list(dict.fromkeys(exog_regressors + [endog_var]))
            exog_data_sm = df_copy[exog_data_sm_cols]
            instrument_data_sm = df_copy[all_instruments_for_sm]
            # Order condition: need at least as many external instruments as
            # endogenous regressors (here exactly one, the treatment)
            num_endog = 1
            num_external_iv = len(valid_instruments)
            if num_endog > num_external_iv:
                raise ValueError(f"Model underidentified: More endogenous regressors ({num_endog}) than unique external instruments ({num_external_iv}).")
            iv_model = IV2SLS(endog=endog_data, exog=exog_data_sm, instrument=instrument_data_sm)
            sm_results_obj = iv_model.fit()
            logger.info("Statsmodels Estimation Summary:")
            logger.info(f"  Estimate for {treatment}: {sm_results_obj.params[treatment]}")
            logger.info(f"  Std Error: {sm_results_obj.bse[treatment]}")
            logger.info(f"  P-value: {sm_results_obj.pvalues[treatment]}")
            results['statsmodels_estimate'] = sm_results_obj.params[treatment]
            results['statsmodels_results_object'] = sm_results_obj
            method_used = 'statsmodels'
            logger.info("Statsmodels estimation successful.")
        except Exception as sm_e:
            logger.error(f"Statsmodels IV estimation also failed: {sm_e}", exc_info=True)
            results['statsmodels_error'] = str(sm_e)
            method_used = 'statsmodels_failed' if method_used == "none" else "dowhy_failed_sm_failed"
    # --- Diagnostics ---
    logger.info("Running diagnostics...")
    diagnostics = run_iv_diagnostics(df, treatment, outcome, valid_instruments, clean_covariates, sm_results_obj, dw_results_obj)
    results['diagnostics'] = diagnostics

    # --- Refutation Step ---
    final_estimate_value = results.get('dowhy_estimate') if method_used == 'dowhy' else results.get('statsmodels_estimate')
    # The permute refuter requires a DoWhy estimate object, so only run it when
    # the estimate is valid AND came from DoWhy
    if method_used == 'dowhy' and dw_results_obj and final_estimate_value is not None:
        logger.info("Running refutation test (placebo treatment, permute)...")
        try:
            refuter_result = model.refute_estimate(
                identified_estimand,
                dw_results_obj,  # Pass the original DoWhy result object
                method_name="placebo_treatment_refuter",
                placebo_type="permute"  # Required for IV per the DoWhy docs/examples
            )
            logger.info("Refutation test completed.")
            logger.debug(f"Refuter Result:\n{refuter_result}")
            # Store the relevant pieces; the refuter result structure varies by version
            refutation_info = getattr(refuter_result, 'refutation_result', None) or {}
            refutation_results = {
                "refuter": "placebo_treatment_refuter",
                "new_effect": getattr(refuter_result, 'new_effect', 'N/A'),
                "p_value": refutation_info.get('p_value', 'N/A'),
                # The test passes if the placebo effect is NOT statistically significant
                "passed": refutation_info.get('is_statistically_significant', None) == False if refutation_info else None
            }
        except Exception as refute_e:
            logger.error(f"Refutation test failed: {refute_e}", exc_info=True)
            refutation_results = {"error": f"Refutation failed: {refute_e}"}
    elif final_estimate_value is not None and method_used == 'statsmodels':
        logger.warning("Skipping placebo permutation refuter: Estimate was generated by statsmodels, not DoWhy's IV estimator.")
        refutation_results = {"status": "skipped_wrong_estimator_for_permute"}
    elif final_estimate_value is None:
        logger.warning("Skipping refutation test because estimation failed.")
        refutation_results = {"status": "skipped_due_to_failed_estimation"}
    else:  # Model or estimand failed earlier, or unknown method_used
        logger.warning(f"Skipping refutation test due to earlier failure (method_used: {method_used}).")
        refutation_results = {"status": "skipped_due_to_model_failure_or_unknown"}
    results['refutation_results'] = refutation_results
    # --- Formatting Results ---
    if final_estimate_value is None and method_used not in ['dowhy', 'statsmodels']:
        logger.error("ERROR: Both estimation methods failed.")
        # Ensure an error key exists if one was not set earlier
        if 'error' not in results:
            results['error'] = "Both DoWhy and statsmodels IV estimation failed."

    logger.info("--- Formatting Final Results ---")
    formatted_results = format_iv_results(
        final_estimate_value,  # The numeric point estimate
        results,               # Dict containing estimate objects and refutation results
        diagnostics,
        treatment,
        outcome,
        valid_instruments,
        method_used,
        llm=llm
    )
    logger.info("--- Instrumental Variable Estimation Complete ---")
    return formatted_results
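
# Minimal smoke-test sketch. Assumptions: dowhy and statsmodels are installed,
# the module is run inside its package (e.g. `python -m <pkg>.<module>`) so the
# relative imports resolve, and the package's run_iv_diagnostics and
# interpret_iv_results helpers tolerate llm=None. Variables Z/X/T/Y are synthetic.
if __name__ == "__main__":
    import numpy as np

    logging.basicConfig(level=logging.INFO)
    rng = np.random.default_rng(0)
    n = 2000
    U = rng.normal(size=n)  # unobserved confounder of T and Y
    Z = rng.normal(size=n)  # instrument: shifts T, excluded from Y
    X = rng.normal(size=n)  # observed covariate
    T = 0.8 * Z + 0.5 * X + U + rng.normal(size=n)
    Y = 2.0 * T + 0.3 * X + U + rng.normal(size=n)  # true effect of T is 2.0
    demo_df = pd.DataFrame({"Z": Z, "X": X, "T": T, "Y": Y})

    out = estimate_effect(demo_df, treatment="T", outcome="Y", covariates=["X"], instrument_variable="Z")
    print(f"Estimated effect (true value 2.0): {out.get('effect_estimate')}")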