""" Dataset Loading Module for Open-Source Training Data. Provides unified loading interfaces for: - DABStep: Multi-step data analysis reasoning - PRIMUS: Cybersecurity domain knowledge - Custom tactical datasets License Attribution: - DABStep: CC-BY-4.0 (Creative Commons Attribution) - PRIMUS: ODC-BY (Open Data Commons Attribution) """ import logging from abc import ABC, abstractmethod from collections.abc import Iterator from dataclasses import dataclass, field from pathlib import Path from typing import Any logger = logging.getLogger(__name__) @dataclass class DatasetSample: """Standardized representation of a dataset sample.""" id: str text: str metadata: dict[str, Any] = field(default_factory=dict) labels: list[str] | None = None difficulty: str | None = None domain: str | None = None reasoning_steps: list[str] | None = None @dataclass class DatasetStatistics: """Statistics about a loaded dataset.""" total_samples: int domains: dict[str, int] avg_text_length: float difficulty_distribution: dict[str, int] total_tokens: int = 0 class DatasetLoader(ABC): """Abstract base class for dataset loaders.""" def __init__(self, cache_dir: str | None = None): """ Initialize dataset loader. Args: cache_dir: Directory to cache downloaded datasets """ self.cache_dir = cache_dir or str(Path.home() / ".cache" / "mcts_datasets") self._dataset = None self._statistics = None @abstractmethod def load(self, split: str = "train") -> list[DatasetSample]: """Load dataset split.""" pass @abstractmethod def get_statistics(self) -> DatasetStatistics: """Get dataset statistics.""" pass @abstractmethod def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]: """Iterate over samples in batches.""" pass class DABStepLoader(DatasetLoader): """ Loader for DABStep Multi-Step Reasoning Dataset. DABStep contains 450+ data analysis tasks requiring sequential, iterative problem-solving. Perfect for training HRM/TRM agents. License: CC-BY-4.0 (Attribution required) Source: huggingface.co/datasets/adyen/DABstep """ DATASET_NAME = "adyen/DABstep" DIFFICULTIES = ["easy", "medium", "hard"] def __init__(self, cache_dir: str | None = None): """Initialize DABStep loader.""" super().__init__(cache_dir) self._loaded_samples: list[DatasetSample] = [] def load(self, split: str = "train", difficulty: str | None = None) -> list[DatasetSample]: """ Load DABStep dataset. Args: split: Dataset split ('train', 'validation', 'test') difficulty: Filter by difficulty ('easy', 'medium', 'hard') Returns: List of DatasetSample objects """ try: from datasets import load_dataset logger.info(f"Loading DABStep dataset (split={split})") dataset = load_dataset( self.DATASET_NAME, cache_dir=self.cache_dir, ) if split not in dataset: available_splits = list(dataset.keys()) logger.warning(f"Split '{split}' not found. Available: {available_splits}") split = available_splits[0] if available_splits else "train" samples = [] for idx, item in enumerate(dataset[split]): sample = DatasetSample( id=f"dabstep_{split}_{idx}", text=str(item.get("question", item.get("text", ""))), metadata={ "source": "DABStep", "license": "CC-BY-4.0", "split": split, "original_data": item, }, difficulty=item.get("difficulty", "medium"), domain="data_analysis", reasoning_steps=item.get("steps", []), ) if difficulty and sample.difficulty != difficulty: continue samples.append(sample) self._loaded_samples = samples logger.info(f"Loaded {len(samples)} DABStep samples") return samples except ImportError: logger.error("datasets library not installed. 
class DABStepLoader(DatasetLoader):
    """
    Loader for the DABStep multi-step reasoning dataset.

    DABStep contains 450+ data analysis tasks requiring sequential,
    iterative problem-solving, making it well suited to training
    HRM/TRM-style multi-step reasoning agents.

    License: CC-BY-4.0 (Attribution required)
    Source: huggingface.co/datasets/adyen/DABstep
    """

    DATASET_NAME = "adyen/DABstep"
    DIFFICULTIES = ["easy", "medium", "hard"]

    def __init__(self, cache_dir: str | None = None):
        """Initialize DABStep loader."""
        super().__init__(cache_dir)
        self._loaded_samples: list[DatasetSample] = []

    def load(self, split: str = "train", difficulty: str | None = None) -> list[DatasetSample]:
        """
        Load the DABStep dataset.

        Args:
            split: Dataset split ('train', 'validation', 'test')
            difficulty: Filter by difficulty ('easy', 'medium', 'hard')

        Returns:
            List of DatasetSample objects
        """
        try:
            from datasets import load_dataset

            logger.info(f"Loading DABStep dataset (split={split})")

            dataset = load_dataset(
                self.DATASET_NAME,
                cache_dir=self.cache_dir,
            )

            if split not in dataset:
                available_splits = list(dataset.keys())
                logger.warning(f"Split '{split}' not found. Available: {available_splits}")
                split = available_splits[0] if available_splits else "train"

            samples = []
            for idx, item in enumerate(dataset[split]):
                sample = DatasetSample(
                    id=f"dabstep_{split}_{idx}",
                    text=str(item.get("question", item.get("text", ""))),
                    metadata={
                        "source": "DABStep",
                        "license": "CC-BY-4.0",
                        "split": split,
                        "original_data": item,
                    },
                    difficulty=item.get("difficulty", "medium"),
                    domain="data_analysis",
                    reasoning_steps=item.get("steps", []),
                )

                if difficulty and sample.difficulty != difficulty:
                    continue

                samples.append(sample)

            self._loaded_samples = samples
            logger.info(f"Loaded {len(samples)} DABStep samples")
            return samples

        except ImportError:
            logger.error("datasets library not installed. Run: pip install datasets")
            raise
        except Exception as e:
            logger.error(f"Failed to load DABStep: {e}")
            raise

    def get_statistics(self) -> DatasetStatistics:
        """Get statistics about loaded DABStep data."""
        if not self._loaded_samples:
            raise ValueError("No samples loaded. Call load() first.")

        difficulty_dist = {}
        total_length = 0

        for sample in self._loaded_samples:
            diff = sample.difficulty or "unknown"
            difficulty_dist[diff] = difficulty_dist.get(diff, 0) + 1
            total_length += len(sample.text)

        return DatasetStatistics(
            total_samples=len(self._loaded_samples),
            domains={"data_analysis": len(self._loaded_samples)},
            avg_text_length=total_length / len(self._loaded_samples),
            difficulty_distribution=difficulty_dist,
        )

    def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]:
        """Iterate over samples in batches."""
        if not self._loaded_samples:
            raise ValueError("No samples loaded. Call load() first.")

        for i in range(0, len(self._loaded_samples), batch_size):
            yield self._loaded_samples[i : i + batch_size]

    def get_reasoning_tasks(self) -> list[DatasetSample]:
        """Get only samples with explicit reasoning steps."""
        return [s for s in self._loaded_samples if s.reasoning_steps]
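
# --- Usage sketch (illustrative, not part of the public API) ------------------
# Shows the intended DABStepLoader call sequence: load a split, optionally
# filtered by difficulty, then inspect statistics and batch-iterate. Requires
# network access and the `datasets` package; the "hard" filter and batch size
# are example values.


def _demo_dabstep(batch_size: int = 16) -> None:
    loader = DABStepLoader()
    loader.load(split="train", difficulty="hard")
    stats = loader.get_statistics()
    logger.info("DABStep difficulty distribution: %s", stats.difficulty_distribution)
    for batch in loader.iterate_samples(batch_size=batch_size):
        _ = batch  # a training loop would consume each batch here
        break  # demo only: stop after the first batch
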
class PRIMUSLoader(DatasetLoader):
    """
    Loader for the PRIMUS cybersecurity dataset suite.

    PRIMUS contains:
    - Seed: 674,848 cybersecurity documents (190M tokens)
    - Instruct: 835 instruction-tuning samples
    - Reasoning: Self-reflection data for reasoning

    License: ODC-BY (Open Data Commons Attribution)
    Source: huggingface.co/datasets/trendmicro-ailab/Primus-Seed
    """

    SEED_DATASET = "trendmicro-ailab/Primus-Seed"
    INSTRUCT_DATASET = "trendmicro-ailab/Primus-Instruct"

    DOMAINS = [
        "mitre_attack",
        "wikipedia",
        "company_sites",
        "threat_intelligence",
        "vulnerability_db",
    ]

    def __init__(self, cache_dir: str | None = None):
        """Initialize PRIMUS loader."""
        super().__init__(cache_dir)
        self._seed_samples: list[DatasetSample] = []
        self._instruct_samples: list[DatasetSample] = []

    def load(
        self,
        split: str = "train",
        dataset_type: str = "seed",
        domains: list[str] | None = None,
        max_samples: int | None = None,
        streaming: bool = True,
    ) -> list[DatasetSample]:
        """
        Load the PRIMUS dataset.

        Args:
            split: Dataset split ('train', 'validation', 'test')
            dataset_type: 'seed' for knowledge base, 'instruct' for fine-tuning
            domains: Filter by specific domains
            max_samples: Limit number of samples (useful for large datasets)
            streaming: Use streaming mode for large datasets (default True);
                only takes effect for the seed dataset when max_samples is set

        Returns:
            List of DatasetSample objects
        """
        try:
            from datasets import load_dataset

            dataset_name = self.SEED_DATASET if dataset_type == "seed" else self.INSTRUCT_DATASET
            logger.info(f"Loading PRIMUS {dataset_type} dataset")

            # Stream the large seed dataset to avoid downloading it in full;
            # streaming is only useful when a sample cap bounds the iteration.
            use_streaming = streaming and dataset_type == "seed" and max_samples is not None

            if use_streaming:
                logger.info(f"Using streaming mode (max_samples={max_samples})")
                dataset = load_dataset(
                    dataset_name,
                    "default",
                    streaming=True,
                    cache_dir=self.cache_dir,
                )
                # For streaming, iterate the 'train' split if present,
                # otherwise the first available split
                if "train" in dataset:
                    data_iter = iter(dataset["train"])
                else:
                    data_iter = iter(dataset[list(dataset.keys())[0]])
            else:
                dataset = load_dataset(
                    dataset_name,
                    cache_dir=self.cache_dir,
                )
                if split not in dataset:
                    available_splits = list(dataset.keys())
                    logger.warning(f"Split '{split}' not found. Using: {available_splits[0]}")
                    split = available_splits[0]
                data_iter = iter(dataset[split])

            samples = []
            count = 0

            for idx, item in enumerate(data_iter):
                if max_samples and count >= max_samples:
                    break

                domain = item.get("domain", item.get("source", "unknown"))

                if domains and domain not in domains:
                    continue

                if dataset_type == "instruct":
                    text = f"Instruction: {item.get('instruction', '')}\nResponse: {item.get('response', '')}"
                else:
                    text = str(item.get("text", item.get("content", "")))

                sample = DatasetSample(
                    id=f"primus_{dataset_type}_{split}_{idx}",
                    text=text,
                    metadata={
                        "source": f"PRIMUS-{dataset_type.capitalize()}",
                        "license": "ODC-BY",
                        "split": split,
                        "original_domain": domain,
                    },
                    domain=domain,
                    labels=item.get("labels", item.get("tags", [])),
                )
                samples.append(sample)
                count += 1

            if dataset_type == "seed":
                self._seed_samples = samples
            else:
                self._instruct_samples = samples

            logger.info(f"Loaded {len(samples)} PRIMUS {dataset_type} samples")
            return samples

        except ImportError:
            logger.error("datasets library not installed. Run: pip install datasets")
            raise
        except Exception as e:
            if "gated dataset" in str(e):
                logger.error(
                    f"PRIMUS is a gated dataset. Please authenticate with HuggingFace:\n"
                    f"1. Create account at https://huggingface.co/\n"
                    f"2. Accept dataset terms at https://huggingface.co/datasets/{dataset_name}\n"
                    f"3. Create token at https://huggingface.co/settings/tokens\n"
                    f"4. Run: huggingface-cli login"
                )
            else:
                logger.error(f"Failed to load PRIMUS: {e}")
            raise

    def load_seed(self, max_samples: int | None = None) -> list[DatasetSample]:
        """Load the PRIMUS-Seed knowledge base."""
        return self.load(dataset_type="seed", max_samples=max_samples)

    def load_instruct(self) -> list[DatasetSample]:
        """Load PRIMUS-Instruct fine-tuning data."""
        return self.load(dataset_type="instruct", streaming=False)

    def get_statistics(self) -> DatasetStatistics:
        """Get statistics about loaded PRIMUS data."""
        all_samples = self._seed_samples + self._instruct_samples
        if not all_samples:
            raise ValueError("No samples loaded. Call load() first.")

        domain_dist = {}
        total_length = 0

        for sample in all_samples:
            domain = sample.domain or "unknown"
            domain_dist[domain] = domain_dist.get(domain, 0) + 1
            total_length += len(sample.text)

        return DatasetStatistics(
            total_samples=len(all_samples),
            domains=domain_dist,
            avg_text_length=total_length / len(all_samples),
            # PRIMUS carries no per-sample difficulty labels
            difficulty_distribution={"unknown": len(all_samples)},
        )

    def iterate_samples(self, batch_size: int = 32) -> Iterator[list[DatasetSample]]:
        """Iterate over all loaded samples in batches."""
        all_samples = self._seed_samples + self._instruct_samples
        if not all_samples:
            raise ValueError("No samples loaded. Call load() first.")

        for i in range(0, len(all_samples), batch_size):
            yield all_samples[i : i + batch_size]

    def get_mitre_attack_samples(self) -> list[DatasetSample]:
        """Get samples specifically from MITRE ATT&CK."""
        return [s for s in self._seed_samples if "mitre" in (s.domain or "").lower()]

    def get_threat_intelligence_samples(self) -> list[DatasetSample]:
        """Get threat-intelligence-related samples."""
        return [
            s
            for s in self._seed_samples
            if any(kw in (s.domain or "").lower() for kw in ["threat", "cti", "intelligence"])
        ]
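
# --- Usage sketch (illustrative, not part of the public API) ------------------
# Shows the intended PRIMUSLoader call sequence. PRIMUS is gated on
# HuggingFace, so this assumes `huggingface-cli login` has already been run;
# the max_samples cap keeps the ~675k-document seed corpus manageable by
# triggering streaming mode in load().


def _demo_primus(max_samples: int = 1000) -> None:
    loader = PRIMUSLoader()
    loader.load_seed(max_samples=max_samples)
    loader.load_instruct()
    stats = loader.get_statistics()
    logger.info("PRIMUS domains: %s", stats.domains)
    logger.info("MITRE ATT&CK samples: %d", len(loader.get_mitre_attack_samples()))
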
class CombinedDatasetLoader:
    """
    Unified loader for combining multiple datasets.

    Provides a single interface for loading and managing:
    - DABStep (multi-step reasoning)
    - PRIMUS (cybersecurity knowledge)
    - Custom tactical datasets
    """

    def __init__(self, cache_dir: str | None = None):
        """Initialize combined loader."""
        self.cache_dir = cache_dir
        self.dabstep_loader = DABStepLoader(cache_dir)
        self.primus_loader = PRIMUSLoader(cache_dir)
        self._all_samples: list[DatasetSample] = []

    def load_all(
        self,
        dabstep_split: str = "train",
        primus_max_samples: int | None = 10000,
        include_instruct: bool = True,
    ) -> list[DatasetSample]:
        """
        Load all datasets.

        Args:
            dabstep_split: Split for DABStep
            primus_max_samples: Max samples from PRIMUS-Seed (None for all)
            include_instruct: Whether to include PRIMUS-Instruct

        Returns:
            Combined list of all samples
        """
        logger.info("Loading combined datasets")

        # Load DABStep
        dabstep_samples = self.dabstep_loader.load(split=dabstep_split)
        logger.info(f"DABStep: {len(dabstep_samples)} samples")

        # Load PRIMUS-Seed
        primus_seed = self.primus_loader.load_seed(max_samples=primus_max_samples)
        logger.info(f"PRIMUS-Seed: {len(primus_seed)} samples")

        # Load PRIMUS-Instruct
        primus_instruct = []
        if include_instruct:
            primus_instruct = self.primus_loader.load_instruct()
            logger.info(f"PRIMUS-Instruct: {len(primus_instruct)} samples")

        self._all_samples = dabstep_samples + primus_seed + primus_instruct
        logger.info(f"Total combined samples: {len(self._all_samples)}")
        return self._all_samples

    def get_domain_distribution(self) -> dict[str, int]:
        """Get distribution of samples across domains."""
        dist = {}
        for sample in self._all_samples:
            domain = sample.domain or "unknown"
            dist[domain] = dist.get(domain, 0) + 1
        return dist

    def filter_by_domain(self, domain: str) -> list[DatasetSample]:
        """Filter samples by domain."""
        return [s for s in self._all_samples if s.domain == domain]

    def get_multi_step_reasoning_samples(self) -> list[DatasetSample]:
        """Get samples suitable for multi-step reasoning training."""
        return [
            s
            for s in self._all_samples
            if s.reasoning_steps
            or s.domain == "data_analysis"
            or "instruct" in s.metadata.get("source", "").lower()
        ]

    def export_for_training(self, output_path: str, format: str = "jsonl") -> str:
        """
        Export the dataset for training.

        Args:
            output_path: Path to save exported data
            format: Export format ('jsonl' only; 'csv' and 'parquet' are
                not yet implemented)

        Returns:
            Path to the exported file
        """
        import json

        output_file = Path(output_path)
        output_file.parent.mkdir(parents=True, exist_ok=True)

        if format == "jsonl":
            with open(output_file, "w", encoding="utf-8") as f:
                for sample in self._all_samples:
                    record = {
                        "id": sample.id,
                        "text": sample.text,
                        "domain": sample.domain,
                        "difficulty": sample.difficulty,
                        "labels": sample.labels,
                        "metadata": sample.metadata,
                    }
                    f.write(json.dumps(record) + "\n")
        else:
            raise NotImplementedError(f"Format {format} not yet supported")

        logger.info(f"Exported {len(self._all_samples)} samples to {output_file}")
        return str(output_file)
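
# --- Usage sketch (illustrative, not part of the public API) ------------------
# End-to-end flow for the combined loader: load everything, inspect the domain
# mix, then export to JSONL for training. The output path and sample cap are
# hypothetical example values.


def _demo_combined(output_path: str = "exports/training_data.jsonl") -> None:
    combined = CombinedDatasetLoader()
    combined.load_all(primus_max_samples=5000)
    logger.info("Domain distribution: %s", combined.get_domain_distribution())
    combined.export_for_training(output_path, format="jsonl")
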
def load_dataset(
    dataset_name: str,
    split: str | None = None,
    cache_dir: str | None = None,
    **kwargs,
) -> Any:
    """
    Unified interface for loading datasets from HuggingFace.

    This function provides compatibility with the standard HuggingFace
    datasets API. It wraps the underlying load_dataset function from the
    datasets library.

    Args:
        dataset_name: HuggingFace dataset identifier (e.g., "adyen/DABstep")
        split: Optional split to load ("train", "validation", "test");
            if None, all splits are returned as a dict
        cache_dir: Optional directory for caching downloaded datasets
        **kwargs: Additional arguments passed to datasets.load_dataset

    Returns:
        HuggingFace Dataset object, or a dict of Dataset objects when no
        split is specified

    Raises:
        ImportError: If the datasets library is not installed
        Exception: If dataset loading fails

    Examples:
        >>> # Load DABStep dataset
        >>> dataset = load_dataset("adyen/DABstep")
        >>> samples = dataset["train"]

        >>> # Load PRIMUS-Seed with custom cache
        >>> dataset = load_dataset("trendmicro-ailab/Primus-Seed", cache_dir="/tmp/cache")

    License Attribution:
        - DABStep: CC-BY-4.0 (Creative Commons Attribution 4.0)
        - PRIMUS: ODC-BY (Open Data Commons Attribution)
    """
    try:
        from datasets import load_dataset as hf_load_dataset

        logger.info(f"Loading dataset: {dataset_name} (split={split})")

        load_kwargs = {**kwargs}
        if cache_dir:
            load_kwargs["cache_dir"] = cache_dir
        # Forward the split only when explicitly requested
        if split is not None:
            load_kwargs["split"] = split

        dataset = hf_load_dataset(dataset_name, **load_kwargs)
        logger.info(f"Successfully loaded dataset: {dataset_name}")
        return dataset

    except ImportError:
        logger.error("datasets library not installed. Run: pip install datasets")
        raise ImportError(
            "The datasets library is required but not installed. "
            "Install it with: pip install datasets"
        )
    except Exception as e:
        logger.error(f"Failed to load dataset {dataset_name}: {e}")
        raise
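
if __name__ == "__main__":
    # Minimal smoke test, assuming network access and (for the PRIMUS demos)
    # prior HuggingFace authentication; sample caps are deliberately small.
    logging.basicConfig(level=logging.INFO)
    _demo_dabstep()
    _demo_primus(max_samples=100)
    _demo_combined()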