""" Data Splitting Module for Training Pipeline. Provides utilities for: - Train/validation/test splitting - Stratified sampling by domain or difficulty - Cross-validation fold creation - Reproducible splits with seeding """ import logging from collections import defaultdict from dataclasses import dataclass from typing import Any from .dataset_loader import DatasetSample logger = logging.getLogger(__name__) @dataclass class DataSplit: """Result of dataset splitting.""" train: list[DatasetSample] validation: list[DatasetSample] test: list[DatasetSample] split_info: dict[str, Any] @dataclass class CrossValidationFold: """Single fold for cross-validation.""" fold_id: int train: list[DatasetSample] validation: list[DatasetSample] class DataSplitter: """ Basic dataset splitter with random sampling. Provides reproducible train/validation/test splits with configurable ratios. """ def __init__(self, seed: int = 42): """ Initialize splitter. Args: seed: Random seed for reproducibility """ self.seed = seed import random self.rng = random.Random(seed) def split( self, samples: list[DatasetSample], train_ratio: float = 0.7, val_ratio: float = 0.15, test_ratio: float = 0.15, shuffle: bool = True, ) -> DataSplit: """ Split dataset into train/validation/test sets. Args: samples: List of all samples train_ratio: Proportion for training (default 0.7) val_ratio: Proportion for validation (default 0.15) test_ratio: Proportion for testing (default 0.15) shuffle: Whether to shuffle before splitting Returns: DataSplit with train, validation, and test sets """ if abs(train_ratio + val_ratio + test_ratio - 1.0) > 0.001: raise ValueError("Ratios must sum to 1.0") if not samples: raise ValueError("Cannot split empty sample list") # Copy and optionally shuffle all_samples = list(samples) if shuffle: self.rng.shuffle(all_samples) n = len(all_samples) train_end = int(n * train_ratio) val_end = train_end + int(n * val_ratio) train_samples = all_samples[:train_end] val_samples = all_samples[train_end:val_end] test_samples = all_samples[val_end:] split_info = { "total_samples": n, "train_samples": len(train_samples), "val_samples": len(val_samples), "test_samples": len(test_samples), "train_ratio": len(train_samples) / n, "val_ratio": len(val_samples) / n, "test_ratio": len(test_samples) / n, "seed": self.seed, "shuffled": shuffle, } logger.info(f"Split {n} samples: train={len(train_samples)}, val={len(val_samples)}, test={len(test_samples)}") return DataSplit( train=train_samples, validation=val_samples, test=test_samples, split_info=split_info, ) def create_k_folds( self, samples: list[DatasetSample], k: int = 5, shuffle: bool = True, ) -> list[CrossValidationFold]: """ Create k-fold cross-validation splits. Args: samples: List of all samples k: Number of folds shuffle: Whether to shuffle before splitting Returns: List of CrossValidationFold objects """ if k < 2: raise ValueError("k must be at least 2") if len(samples) < k: raise ValueError(f"Need at least {k} samples for {k}-fold CV") # Copy and optionally shuffle all_samples = list(samples) if shuffle: self.rng.shuffle(all_samples) # Calculate fold sizes fold_size = len(all_samples) // k folds = [] for fold_id in range(k): # Validation is the current fold val_start = fold_id * fold_size val_end = len(all_samples) if fold_id == k - 1 else val_start + fold_size # noqa: SIM108 val_samples = all_samples[val_start:val_end] train_samples = all_samples[:val_start] + all_samples[val_end:] folds.append( CrossValidationFold( fold_id=fold_id, train=train_samples, validation=val_samples, ) ) logger.info(f"Created {k}-fold cross-validation splits") return folds class StratifiedSplitter(DataSplitter): """ Stratified dataset splitter. Ensures proportional representation of categories (domain, difficulty, etc.) across splits. """ def __init__(self, seed: int = 42, stratify_by: str = "domain"): """ Initialize stratified splitter. Args: seed: Random seed for reproducibility stratify_by: Attribute to stratify on ('domain', 'difficulty', 'labels') """ super().__init__(seed) self.stratify_by = stratify_by def split( self, samples: list[DatasetSample], train_ratio: float = 0.7, val_ratio: float = 0.15, test_ratio: float = 0.15, shuffle: bool = True, ) -> DataSplit: """ Stratified split maintaining category proportions. Args: samples: List of all samples train_ratio: Proportion for training val_ratio: Proportion for validation test_ratio: Proportion for testing shuffle: Whether to shuffle before splitting Returns: DataSplit with stratified train, validation, and test sets """ if abs(train_ratio + val_ratio + test_ratio - 1.0) > 0.001: raise ValueError("Ratios must sum to 1.0") if not samples: raise ValueError("Cannot split empty sample list") # Group samples by stratification key groups = defaultdict(list) for sample in samples: key = self._get_stratify_key(sample) groups[key].append(sample) # Split each group proportionally train_samples = [] val_samples = [] test_samples = [] for _key, group_samples in groups.items(): if shuffle: self.rng.shuffle(group_samples) n = len(group_samples) train_end = int(n * train_ratio) val_end = train_end + int(n * val_ratio) train_samples.extend(group_samples[:train_end]) val_samples.extend(group_samples[train_end:val_end]) test_samples.extend(group_samples[val_end:]) # Final shuffle of combined sets if shuffle: self.rng.shuffle(train_samples) self.rng.shuffle(val_samples) self.rng.shuffle(test_samples) # Verify stratification stratify_info = self._verify_stratification(train_samples, val_samples, test_samples) split_info = { "total_samples": len(samples), "train_samples": len(train_samples), "val_samples": len(val_samples), "test_samples": len(test_samples), "train_ratio": len(train_samples) / len(samples), "val_ratio": len(val_samples) / len(samples), "test_ratio": len(test_samples) / len(samples), "stratify_by": self.stratify_by, "stratification_info": stratify_info, "seed": self.seed, "shuffled": shuffle, } logger.info( f"Stratified split ({self.stratify_by}): " f"train={len(train_samples)}, val={len(val_samples)}, " f"test={len(test_samples)}" ) return DataSplit( train=train_samples, validation=val_samples, test=test_samples, split_info=split_info, ) def _get_stratify_key(self, sample: DatasetSample) -> str: """Get stratification key for a sample.""" if self.stratify_by == "domain": return sample.domain or "unknown" elif self.stratify_by == "difficulty": return sample.difficulty or "unknown" elif self.stratify_by == "labels": return ",".join(sorted(sample.labels)) if sample.labels else "unknown" else: return str(getattr(sample, self.stratify_by, "unknown")) def _verify_stratification( self, train: list[DatasetSample], val: list[DatasetSample], test: list[DatasetSample], ) -> dict[str, dict[str, float]]: """ Verify that stratification was successful. Returns dictionary showing distribution of stratification key across train/val/test splits. """ def get_distribution(samples: list[DatasetSample]) -> dict[str, float]: if not samples: return {} counts = defaultdict(int) for sample in samples: key = self._get_stratify_key(sample) counts[key] += 1 total = len(samples) return {k: v / total for k, v in counts.items()} return { "train": get_distribution(train), "validation": get_distribution(val), "test": get_distribution(test), } def create_stratified_k_folds( self, samples: list[DatasetSample], k: int = 5, shuffle: bool = True, ) -> list[CrossValidationFold]: """ Create stratified k-fold cross-validation splits. Args: samples: List of all samples k: Number of folds shuffle: Whether to shuffle before splitting Returns: List of CrossValidationFold objects with stratification """ if k < 2: raise ValueError("k must be at least 2") # Group samples by stratification key groups = defaultdict(list) for sample in samples: key = self._get_stratify_key(sample) groups[key].append(sample) # Initialize folds folds_data = [{"train": [], "val": []} for _ in range(k)] # Distribute each group across folds for _key, group_samples in groups.items(): if shuffle: self.rng.shuffle(group_samples) # Assign samples to folds fold_size = len(group_samples) // k for fold_id in range(k): val_start = fold_id * fold_size val_end = len(group_samples) if fold_id == k - 1 else val_start + fold_size for i, sample in enumerate(group_samples): if val_start <= i < val_end: folds_data[fold_id]["val"].append(sample) else: folds_data[fold_id]["train"].append(sample) # Create fold objects folds = [ CrossValidationFold( fold_id=i, train=data["train"], validation=data["val"], ) for i, data in enumerate(folds_data) ] logger.info(f"Created stratified {k}-fold cross-validation splits") return folds class BalancedSampler: """ Balanced sampling for imbalanced datasets. Provides utilities for: - Oversampling minority classes - Undersampling majority classes - SMOTE-like synthetic sampling (for numerical features) """ def __init__(self, seed: int = 42): """Initialize balanced sampler.""" self.seed = seed import random self.rng = random.Random(seed) def oversample_minority( self, samples: list[DatasetSample], target_key: str = "domain", target_ratio: float = 1.0, ) -> list[DatasetSample]: """ Oversample minority classes to balance dataset. Args: samples: Original samples target_key: Attribute to balance on target_ratio: Target ratio relative to majority (1.0 = equal) Returns: Balanced sample list (originals + oversampled) """ # Group by target key groups = defaultdict(list) for sample in samples: key = getattr(sample, target_key, "unknown") or "unknown" groups[key].append(sample) # Find majority class size max_count = max(len(g) for g in groups.values()) target_count = int(max_count * target_ratio) # Oversample minority classes balanced = [] for _key, group in groups.items(): balanced.extend(group) # Oversample if needed if len(group) < target_count: num_to_add = target_count - len(group) for _ in range(num_to_add): # Randomly duplicate from group original = self.rng.choice(group) duplicate = DatasetSample( id=f"{original.id}_oversample_{self.rng.randint(0, 999999)}", text=original.text, metadata={**original.metadata, "oversampled": True}, labels=original.labels, difficulty=original.difficulty, domain=original.domain, reasoning_steps=original.reasoning_steps, ) balanced.append(duplicate) logger.info(f"Oversampled from {len(samples)} to {len(balanced)} samples") return balanced def undersample_majority( self, samples: list[DatasetSample], target_key: str = "domain", target_ratio: float = 1.0, ) -> list[DatasetSample]: """ Undersample majority classes to balance dataset. Args: samples: Original samples target_key: Attribute to balance on target_ratio: Target ratio relative to minority (1.0 = equal) Returns: Balanced sample list (subset of originals) """ # Group by target key groups = defaultdict(list) for sample in samples: key = getattr(sample, target_key, "unknown") or "unknown" groups[key].append(sample) # Find minority class size min_count = min(len(g) for g in groups.values()) target_count = int(min_count * target_ratio) # Undersample majority classes balanced = [] for _key, group in groups.items(): if len(group) > target_count: # Randomly select target_count samples balanced.extend(self.rng.sample(group, target_count)) else: balanced.extend(group) logger.info(f"Undersampled from {len(samples)} to {len(balanced)} samples") return balanced def get_class_distribution( self, samples: list[DatasetSample], target_key: str = "domain", ) -> dict[str, int]: """ Get distribution of classes. Args: samples: Sample list target_key: Attribute to analyze Returns: Dictionary of class counts """ distribution = defaultdict(int) for sample in samples: key = getattr(sample, target_key, "unknown") or "unknown" distribution[key] += 1 return dict(distribution)