Spaces:
Running
Running
| import pytest | |
| import os | |
| import pandas as pd | |
| # Import the refactored parse_input function | |
| from auto_causal.components import input_parser | |
# Real-LLM tests need credentials: record whether an OpenAI key is configured
# and build a reusable marker that skips a test when it is not.
api_key_present = bool(os.getenv("OPENAI_API_KEY"))
skip_if_no_key = pytest.mark.skipif(
    not api_key_present,
    reason="OPENAI_API_KEY environment variable not set",
)
@skip_if_no_key
def test_parse_input_with_real_llm():
    """Exercise ``input_parser.parse_input`` end to end against the actual LLM.

    Three queries are sent through the parser:

    1. An effect-estimation query that names a dataset path and a filter
       (``year > 2010``); a dummy CSV is created at that path for the call.
    2. A counterfactual query with no dataset path and no constraints.
    3. A terse, ambiguous query; only checks that a result is returned.

    Note: This test requires the OPENAI_API_KEY environment variable to be set
    and will make a real API call. It is skipped (via ``skip_if_no_key``) when
    the key is absent.
    """
    # --- Test Case 1: Effect query with dataset and constraint ---
    query1 = "analyze the effect of 'Minimum Wage Increase' on 'Unemployment Rate' using data/county_data.csv where year > 2010"

    # Provide some dummy dataset context
    dataset_info1 = {
        'columns': ['County', 'Year', 'Minimum Wage Increase', 'Unemployment Rate', 'Population'],
        'column_types': {
            'County': 'object',
            'Year': 'int64',
            'Minimum Wage Increase': 'int64',
            'Unemployment Rate': 'float64',
            'Population': 'int64',
        },
        'sample_rows': [
            {'County': 'A', 'Year': 2009, 'Minimum Wage Increase': 0, 'Unemployment Rate': 5.5, 'Population': 10000},
            {'County': 'A', 'Year': 2011, 'Minimum Wage Increase': 1, 'Unemployment Rate': 6.0, 'Population': 10200},
        ],
    }

    # Create a dummy data file for path checking (relative to workspace root).
    # NOTE(review): this overwrites and later deletes any pre-existing
    # data/county_data.csv in the working directory -- confirm that is acceptable.
    dummy_file_path = "data/county_data.csv"
    dummy_dir = os.path.dirname(dummy_file_path)
    os.makedirs(dummy_dir, exist_ok=True)
    try:
        with open(dummy_file_path, 'w', encoding='utf-8') as f:
            f.write("County,Year,Minimum Wage Increase,Unemployment Rate,Population\n")
            f.write("A,2009,0,5.5,10000\n")
            f.write("A,2011,1,6.0,10200\n")

        result1 = input_parser.parse_input(query=query1, dataset_info=dataset_info1)
    finally:
        # Clean up even when parse_input raises, so a failed run does not
        # leave the dummy file or directory behind in the workspace.
        if os.path.exists(dummy_file_path):
            os.remove(dummy_file_path)
        # Try removing the directory if empty
        try:
            os.rmdir(dummy_dir)
        except OSError:
            pass  # Ignore if directory is not empty or other error

    # Assertions for Test Case 1
    assert result1 is not None
    assert result1['original_query'] == query1
    assert result1['query_type'] == "EFFECT_ESTIMATION"
    assert result1['dataset_path'] == dummy_file_path  # Check if path extraction worked

    # Check variables (allowing for some LLM interpretation flexibility)
    assert 'treatment' in result1['extracted_variables']
    assert 'outcome' in result1['extracted_variables']
    # Check if the core variable names are present in the extracted lists
    assert any('Minimum Wage Increase' in t for t in result1['extracted_variables'].get('treatment', []))
    assert any('Unemployment Rate' in o for o in result1['extracted_variables'].get('outcome', []))

    # Check constraints
    assert isinstance(result1['constraints'], list)
    # Check if a constraint related to 'year > 2010' was captured (LLM might phrase it differently)
    assert any('year' in c.lower() and '2010' in c for c in result1.get('constraints', [])), "Constraint 'year > 2010' not found or not parsed correctly."

    # --- Test Case 2: Counterfactual without dataset path ---
    query2 = "What would sales have been if we hadn't run the 'Summer Sale' campaign?"
    dataset_info2 = {
        'columns': ['Date', 'Sales', 'Summer Sale', 'Competitor Activity'],
        'column_types': {
            'Date': 'datetime64[ns]',
            'Sales': 'float64',
            'Summer Sale': 'int64',
            'Competitor Activity': 'float64',
        },
    }
    result2 = input_parser.parse_input(query=query2, dataset_info=dataset_info2)

    # Assertions for Test Case 2
    assert result2 is not None
    assert result2['query_type'] == "COUNTERFACTUAL"
    assert result2['dataset_path'] is None  # No path mentioned or inferrable here
    assert any('Summer Sale' in t for t in result2['extracted_variables'].get('treatment', []))
    assert any('Sales' in o for o in result2['extracted_variables'].get('outcome', []))
    assert not result2['constraints']  # No constraints expected

    # --- Test Case 3: Simple query, LLM might fail validation? ---
    # This tests if the retry/failure mechanism logs warnings but doesn't crash
    # (Assuming LLM might struggle to extract treatment/outcome from just "sales vs ads")
    query3 = "sales vs ads"
    dataset_info3 = {
        'columns': ['sales', 'ads'],
        'column_types': {'sales': 'float', 'ads': 'float'},
    }
    result3 = input_parser.parse_input(query=query3, dataset_info=dataset_info3)
    assert result3 is not None
    # LLM might fail extraction; check default/fallback values
    # Query type might default to OTHER or CORRELATION/DESCRIPTIVE
    # Variables might be empty or partially filled
    # This mainly checks that the function completes without error even if LLM fails
    print(f"Result for ambiguous query: {result3}")