| """ | |
| Difference in Means / Simple Linear Regression Estimator. | |
| Estimates the Average Treatment Effect (ATE) by comparing the mean outcome | |
| between the treated and control groups. This is equivalent to a simple OLS | |
| regression of the outcome on the treatment indicator. | |
| Assumes no confounding (e.g., suitable for RCT data). | |
| """ | |
| import pandas as pd | |
| import statsmodels.api as sm | |
| import numpy as np | |
| import warnings | |
| from typing import Dict, Any, Optional | |
| import logging | |
| from langchain.chat_models.base import BaseChatModel # For type hinting llm | |
| from .diagnostics import run_dim_diagnostics | |
| from .llm_assist import interpret_dim_results | |
| logger = logging.getLogger(__name__) | |
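
# Note: with a single 0/1 regressor plus an intercept, the OLS slope on the
# treatment dummy is algebraically identical to mean(Y | T=1) - mean(Y | T=0),
# and its conventional (homoskedastic) standard error matches the pooled
# two-sample t-test. That equivalence is why "Difference in Means" is
# implemented below as a simple OLS fit.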

def estimate_effect(
    df: pd.DataFrame,
    treatment: str,
    outcome: str,
    query: Optional[str] = None,  # For potential LLM use
    llm: Optional[BaseChatModel] = None,  # For potential LLM use
    **kwargs,  # Captures any other arguments (e.g., covariates), which are ignored
) -> Dict[str, Any]:
| """ | |
| Estimates the causal effect using Difference in Means (via OLS). | |
| Ignores any provided covariates. | |
| Args: | |
| df: Input DataFrame. | |
| treatment: Name of the binary treatment variable column (should be 0 or 1). | |
| outcome: Name of the outcome variable column. | |
| query: Optional user query for context. | |
| llm: Optional Language Model instance. | |
| **kwargs: Additional keyword arguments (ignored). | |
| Returns: | |
| Dictionary containing estimation results: | |
| - 'effect_estimate': The difference in means (treatment coefficient). | |
| - 'p_value': The p-value associated with the difference. | |
| - 'confidence_interval': The 95% confidence interval for the difference. | |
| - 'standard_error': The standard error of the difference. | |
| - 'formula': The regression formula used. | |
| - 'model_summary': Summary object from statsmodels. | |
| - 'diagnostics': Basic group statistics. | |
| - 'interpretation': LLM interpretation. | |
| """ | |
    required_cols = [treatment, outcome]
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns: {missing_cols}")

    # Validate that the treatment is binary (0/1)
    treat_vals = df[treatment].dropna().unique()
    if not np.all(np.isin(treat_vals, [0, 1])):
        warnings.warn(
            f"Treatment column '{treatment}' contains values other than 0 and 1: "
            f"{treat_vals}. Proceeding, but results may be unreliable.",
            UserWarning,
        )
        # Optional: raise ValueError here if a strictly binary treatment is required.

    # Prepare data for statsmodels: drop rows with NaNs in the required columns.
    df_analysis = df[required_cols].dropna()
    if df_analysis.empty:
        raise ValueError("No data remaining after dropping NaNs for required columns.")

    X = df_analysis[[treatment]]
    X = sm.add_constant(X)  # Add intercept ('const' column)
    y = df_analysis[outcome]

    # Descriptive formula string for reporting only; the intercept is added manually above.
    formula = f"{outcome} ~ {treatment} + const"
    logger.info(f"Running Difference in Means regression: {formula}")
    try:
        model = sm.OLS(y, X)
        results = model.fit()

        effect_estimate = results.params[treatment]
        p_value = results.pvalues[treatment]
        conf_int = results.conf_int(alpha=0.05).loc[treatment].tolist()
        std_err = results.bse[treatment]

        # Run basic diagnostics (group means, stds, counts)
        diag_results = run_dim_diagnostics(df_analysis, treatment, outcome)

        # Get interpretation
        interpretation = interpret_dim_results(results, diag_results, treatment, llm=llm)

        return {
            'effect_estimate': effect_estimate,
            'p_value': p_value,
            'confidence_interval': conf_int,
            'standard_error': std_err,
            'formula': formula,
            'model_summary': results.summary(),
            'diagnostics': diag_results,
            'interpretation': interpretation,
            'method_used': 'Difference in Means (OLS)',
        }
    except Exception as e:
        logger.error(f"Difference in Means failed: {e}")
        raise
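

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): simulate a small RCT with a known
    # treatment effect of 2.0 and check that the estimator recovers it. Note that
    # the relative imports above mean this module must be run within its package,
    # and interpret_dim_results is assumed to handle llm=None gracefully.
    rng = np.random.default_rng(0)
    n = 1_000
    demo = pd.DataFrame({"treated": rng.integers(0, 2, size=n)})
    demo["y"] = 1.0 + 2.0 * demo["treated"] + rng.normal(0.0, 1.0, size=n)

    result = estimate_effect(demo, treatment="treated", outcome="y")
    print(f"Estimated ATE: {result['effect_estimate']:.3f} "
          f"(95% CI: {result['confidence_interval']})")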