| """ | |
| Experiment Tracking Integration Module. | |
| Provides unified interface for: | |
| - Braintrust experiment tracking | |
| - Weights & Biases (W&B) logging | |
| - Metric collection and visualization | |
| - Model artifact versioning | |
| """ | |
| import logging | |
| import os | |
| import time | |
| from dataclasses import dataclass, field | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any | |
| logger = logging.getLogger(__name__) | |


@dataclass
class ExperimentConfig:
    """Configuration for experiment tracking."""

    project_name: str
    experiment_name: str
    tags: list[str] = field(default_factory=list)
    description: str = ""
    save_artifacts: bool = True
    log_frequency: int = 1  # Log every N steps


@dataclass
class TrainingMetrics:
    """Standard training metrics."""

    epoch: int
    step: int
    train_loss: float
    val_loss: float | None = None
    accuracy: float | None = None
    learning_rate: float | None = None
    timestamp: float = field(default_factory=time.time)
    custom_metrics: dict[str, float] = field(default_factory=dict)
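

# Illustrative example (not part of the module's API surface): constructing a
# TrainingMetrics record for a single optimization step. All numeric values
# below are placeholders.
#
#     step_metrics = TrainingMetrics(
#         epoch=1,
#         step=250,
#         train_loss=0.412,
#         val_loss=0.487,
#         accuracy=0.83,
#         learning_rate=3e-4,
#         custom_metrics={"grad_norm": 1.7},
#     )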


class BraintrustTracker:
    """
    Braintrust experiment tracking integration.

    Provides:
    - Experiment initialization and management
    - Metric logging with automatic visualization
    - Hyperparameter tracking
    - Model evaluation scoring
    - Artifact versioning
    """

    def __init__(self, api_key: str | None = None, project_name: str = "mcts-neural-meta-controller"):
        """
        Initialize Braintrust tracker.

        Args:
            api_key: Braintrust API key (or from BRAINTRUST_API_KEY env var)
            project_name: Project name in Braintrust
        """
        self.api_key = api_key or os.getenv("BRAINTRUST_API_KEY")
        self.project_name = project_name
        self._experiment = None
        self._experiment_id = None
        self._experiment_config: dict[str, Any] = {}  # Populated by init_experiment (offline mode)
        self._metrics_buffer: list[dict[str, Any]] = []
        self._initialized = False

        if not self.api_key:
            logger.warning("BRAINTRUST_API_KEY not set. Using offline mode.")
            self._offline_mode = True
        else:
            self._offline_mode = False
            self._initialize_client()

    def _initialize_client(self):
        """Initialize Braintrust client."""
        try:
            import braintrust

            braintrust.login(api_key=self.api_key)
            self._bt = braintrust
            self._initialized = True
            logger.info(f"Braintrust client initialized for project: {self.project_name}")
        except ImportError:
            logger.error("braintrust library not installed. Run: pip install braintrust")
            self._offline_mode = True
        except Exception as e:
            logger.error(f"Failed to initialize Braintrust: {e}")
            self._offline_mode = True

    def init_experiment(
        self,
        name: str,
        description: str = "",
        tags: list[str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> str:
        """
        Initialize a new experiment.

        Args:
            name: Experiment name (e.g., "rnn_meta_controller_v2")
            description: Experiment description
            tags: List of tags for filtering
            metadata: Additional metadata

        Returns:
            Experiment ID
        """
        if self._offline_mode:
            exp_id = f"offline_{int(time.time())}"
            logger.info(f"Created offline experiment: {exp_id}")
            self._experiment_id = exp_id
            self._experiment_config = {
                "name": name,
                "description": description,
                "tags": tags or [],
                "metadata": metadata or {},
                "start_time": datetime.now().isoformat(),
            }
            return exp_id

        try:
            self._experiment = self._bt.init(
                project=self.project_name,
                experiment=name,
            )
            self._experiment_id = self._experiment.id
            logger.info(f"Created Braintrust experiment: {name} (ID: {self._experiment_id})")
            return self._experiment_id
        except Exception as e:
            logger.error(f"Failed to create experiment: {e}")
            # Switch to offline mode before retrying, otherwise the fallback
            # call would recurse back into this branch indefinitely.
            self._offline_mode = True
            return self.init_experiment(name, description, tags, metadata)

    def log_hyperparameters(self, params: dict[str, Any]):
        """
        Log hyperparameters for the experiment.

        Args:
            params: Dictionary of hyperparameters
        """
        logger.info(f"Logging hyperparameters: {params}")

        if self._offline_mode:
            self._experiment_config["hyperparameters"] = params
            return

        try:
            if self._experiment:
                # Braintrust uses metadata for hyperparameters
                self._experiment.log(
                    input="hyperparameters",
                    output=params,
                    metadata={"type": "hyperparameters"},
                )
        except Exception as e:
            logger.error(f"Failed to log hyperparameters: {e}")

    def log_metric(
        self,
        name: str,
        value: float,
        step: int | None = None,
        timestamp: float | None = None,
    ):
        """
        Log a single metric.

        Args:
            name: Metric name
            value: Metric value
            step: Optional step number
            timestamp: Optional timestamp
        """
        metric_data = {
            "name": name,
            "value": value,
            "step": step or len(self._metrics_buffer),
            "timestamp": timestamp or time.time(),
        }
        self._metrics_buffer.append(metric_data)

        if self._offline_mode:
            logger.debug(f"Metric logged (offline): {name}={value}")
            return

        try:
            if self._experiment:
                self._experiment.log(
                    input=f"metric_{name}",
                    output={"value": value},
                    scores={name: value},
                    metadata={"step": step},
                )
        except Exception as e:
            logger.error(f"Failed to log metric {name}: {e}")

    def log_training_step(self, metrics: TrainingMetrics):
        """
        Log a complete training step.

        Args:
            metrics: TrainingMetrics object
        """
        self.log_metric("train_loss", metrics.train_loss, step=metrics.step)
        if metrics.val_loss is not None:
            self.log_metric("val_loss", metrics.val_loss, step=metrics.step)
        if metrics.accuracy is not None:
            self.log_metric("accuracy", metrics.accuracy, step=metrics.step)
        if metrics.learning_rate is not None:
            self.log_metric("learning_rate", metrics.learning_rate, step=metrics.step)
        for key, value in metrics.custom_metrics.items():
            self.log_metric(key, value, step=metrics.step)

    def log_evaluation(
        self,
        input_data: Any,
        output: Any,
        expected: Any,
        scores: dict[str, float],
        metadata: dict[str, Any] | None = None,
    ):
        """
        Log an evaluation result.

        Args:
            input_data: Input to the model
            output: Model output
            expected: Expected output
            scores: Dictionary of scores (e.g., accuracy, f1)
            metadata: Additional metadata
        """
        if self._offline_mode:
            logger.info(f"Evaluation logged (offline): scores={scores}")
            return

        try:
            if self._experiment:
                self._experiment.log(
                    input=input_data,
                    output=output,
                    expected=expected,
                    scores=scores,
                    metadata=metadata or {},
                )
        except Exception as e:
            logger.error(f"Failed to log evaluation: {e}")

    def log_artifact(self, path: str | Path, name: str | None = None):
        """
        Log a model artifact.

        Args:
            path: Path to artifact file
            name: Optional artifact name
        """
        path = Path(path)
        if not path.exists():
            logger.warning(f"Artifact not found: {path}")
            return

        logger.info(f"Logging artifact: {path}")

        if self._offline_mode:
            if "artifacts" not in self._experiment_config:
                self._experiment_config["artifacts"] = []
            self._experiment_config["artifacts"].append(str(path))
            return

        # Braintrust artifact logging would go here
        # For now, just log the path
        try:
            if self._experiment:
                self._experiment.log(
                    input=f"artifact_{name or path.name}",
                    output={"path": str(path), "name": name or path.name},
                    metadata={"artifact_path": str(path), "artifact_name": name or path.name},
                )
        except Exception as e:
            logger.error(f"Failed to log artifact: {e}")

    def get_summary(self) -> dict[str, Any]:
        """
        Get experiment summary.

        Returns:
            Dictionary with experiment summary
        """
        if self._offline_mode:
            return {
                "id": self._experiment_id,
                "config": self._experiment_config,
                "metrics_count": len(self._metrics_buffer),
                "offline": True,
            }
        return {
            "id": self._experiment_id,
            "project": self.project_name,
            "metrics_count": len(self._metrics_buffer),
            "offline": False,
        }

    def end_experiment(self):
        """End the current experiment."""
        summary = self.get_summary()
        logger.info(f"Experiment ended: {summary}")

        if not self._offline_mode and self._experiment:
            try:
                # Braintrust experiments auto-close, but we'll try explicit close if available
                if hasattr(self._experiment, "close"):
                    self._experiment.close()
                elif hasattr(self._experiment, "flush"):
                    self._experiment.flush()
            except Exception as e:
                logger.error(f"Failed to end experiment: {e}")

        self._experiment = None
        self._experiment_id = None
        self._metrics_buffer = []
        return summary
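

# Usage sketch (illustrative; names and values are placeholders). Without
# BRAINTRUST_API_KEY set, the tracker runs in offline mode and only buffers
# metrics locally.
#
#     bt = BraintrustTracker(project_name="mcts-neural-meta-controller")
#     bt.init_experiment("rnn_meta_controller_v2", tags=["baseline"])
#     bt.log_hyperparameters({"learning_rate": 1e-3, "hidden_size": 128})
#     bt.log_metric("train_loss", 0.42, step=1)
#     bt.log_evaluation(
#         input_data={"state": [0.1, 0.2]},
#         output={"action": 3},
#         expected={"action": 3},
#         scores={"exact_match": 1.0},
#     )
#     summary = bt.end_experiment()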


class WandBTracker:
    """
    Weights & Biases experiment tracking integration.

    Provides:
    - Real-time metric visualization
    - Hyperparameter sweep management
    - Model artifact logging
    - Collaborative experiment comparison
    """

    def __init__(
        self,
        api_key: str | None = None,
        project_name: str = "mcts-neural-meta-controller",
        entity: str | None = None,
    ):
        """
        Initialize W&B tracker.

        Args:
            api_key: W&B API key (or from WANDB_API_KEY env var)
            project_name: Project name in W&B
            entity: W&B entity (team or username)
        """
        self.api_key = api_key or os.getenv("WANDB_API_KEY")
        self.project_name = project_name
        self.entity = entity
        self._run = None
        self._run_config: dict[str, Any] = {}  # Offline-mode config store, updated by init_run/update_config
        self._initialized = False
        self._offline_mode = os.getenv("WANDB_MODE") == "offline"

        if not self.api_key and not self._offline_mode:
            logger.warning("WANDB_API_KEY not set. Using offline mode.")
            self._offline_mode = True
            os.environ["WANDB_MODE"] = "offline"
        else:
            self._initialize_client()

    def _initialize_client(self):
        """Initialize W&B client."""
        try:
            import wandb

            if self.api_key:
                wandb.login(key=self.api_key)
            self._wandb = wandb
            self._initialized = True
            logger.info(f"W&B client initialized for project: {self.project_name}")
        except ImportError:
            logger.error("wandb library not installed. Run: pip install wandb")
            self._offline_mode = True
        except Exception as e:
            logger.error(f"Failed to initialize W&B: {e}")
            self._offline_mode = True

    def init_run(
        self,
        name: str,
        config: dict[str, Any] | None = None,
        tags: list[str] | None = None,
        notes: str = "",
    ):
        """
        Initialize a new W&B run.

        Args:
            name: Run name
            config: Configuration dictionary
            tags: List of tags
            notes: Run notes/description

        Returns:
            Run object
        """
        if self._offline_mode:
            logger.info(f"W&B run initialized (offline mode): {name}")
            self._run_config = config or {}
            return None

        try:
            self._run = self._wandb.init(
                project=self.project_name,
                entity=self.entity,
                name=name,
                config=config,
                tags=tags,
                notes=notes,
            )
            logger.info(f"W&B run initialized: {name}")
            return self._run
        except Exception as e:
            logger.error(f"Failed to initialize W&B run: {e}")
            self._offline_mode = True
            return None

    def log(self, metrics: dict[str, Any], step: int | None = None):
        """
        Log metrics to W&B.

        Args:
            metrics: Dictionary of metrics
            step: Optional step number
        """
        if self._offline_mode:
            logger.debug(f"W&B metrics (offline): {metrics}")
            return

        try:
            if self._run:
                self._wandb.log(metrics, step=step)
        except Exception as e:
            logger.error(f"Failed to log to W&B: {e}")

    def log_training_step(self, metrics: TrainingMetrics):
        """
        Log a complete training step to W&B.

        Args:
            metrics: TrainingMetrics object
        """
        log_data = {
            "epoch": metrics.epoch,
            "train_loss": metrics.train_loss,
        }
        if metrics.val_loss is not None:
            log_data["val_loss"] = metrics.val_loss
        if metrics.accuracy is not None:
            log_data["accuracy"] = metrics.accuracy
        if metrics.learning_rate is not None:
            log_data["learning_rate"] = metrics.learning_rate
        log_data.update(metrics.custom_metrics)
        self.log(log_data, step=metrics.step)

    def update_config(self, config: dict[str, Any]):
        """
        Update run configuration.

        Args:
            config: Configuration updates
        """
        if self._offline_mode:
            self._run_config.update(config)
            return

        try:
            if self._run:
                self._wandb.config.update(config)
        except Exception as e:
            logger.error(f"Failed to update W&B config: {e}")

    def watch_model(self, model, log_freq: int = 100):
        """
        Watch model gradients and parameters.

        Args:
            model: PyTorch model
            log_freq: Logging frequency
        """
        if self._offline_mode:
            return

        try:
            if self._run:
                self._wandb.watch(model, log="all", log_freq=log_freq)
        except Exception as e:
            logger.error(f"Failed to watch model: {e}")

    def log_artifact(self, path: str | Path, name: str, artifact_type: str = "model"):
        """
        Log artifact to W&B.

        Args:
            path: Path to artifact
            name: Artifact name
            artifact_type: Type of artifact (model, dataset, etc.)
        """
        if self._offline_mode:
            logger.info(f"Artifact logged (offline): {path}")
            return

        try:
            artifact = self._wandb.Artifact(name, type=artifact_type)
            artifact.add_file(str(path))
            if self._run:
                self._run.log_artifact(artifact)
            logger.info(f"Artifact logged: {name}")
        except Exception as e:
            logger.error(f"Failed to log artifact: {e}")

    def finish(self):
        """Finish the W&B run."""
        if self._offline_mode:
            logger.info("W&B run finished (offline)")
            return

        try:
            if self._run:
                self._run.finish()
                logger.info("W&B run finished")
        except Exception as e:
            logger.error(f"Failed to finish W&B run: {e}")


class UnifiedExperimentTracker:
    """
    Unified experiment tracker that coordinates both Braintrust and W&B.

    Provides single interface for:
    - Dual logging to both platforms
    - Fallback handling
    - Consistent metric tracking
    """

    def __init__(
        self,
        braintrust_api_key: str | None = None,
        wandb_api_key: str | None = None,
        project_name: str = "mcts-neural-meta-controller",
    ):
        """
        Initialize unified tracker.

        Args:
            braintrust_api_key: Braintrust API key
            wandb_api_key: W&B API key
            project_name: Project name for both platforms
        """
        self.bt = BraintrustTracker(api_key=braintrust_api_key, project_name=project_name)
        self.wandb = WandBTracker(api_key=wandb_api_key, project_name=project_name)
        self.project_name = project_name

    def init_experiment(
        self,
        name: str,
        config: dict[str, Any] | None = None,
        description: str = "",
        tags: list[str] | None = None,
    ):
        """
        Initialize experiment on both platforms.

        Args:
            name: Experiment/run name
            config: Configuration dictionary
            description: Description
            tags: List of tags
        """
        self.bt.init_experiment(name, description, tags)
        self.wandb.init_run(name, config, tags, description)
        if config:
            self.bt.log_hyperparameters(config)
        logger.info(f"Unified experiment initialized: {name}")

    def log_metrics(self, metrics: TrainingMetrics):
        """
        Log training metrics to both platforms.

        Args:
            metrics: TrainingMetrics object
        """
        self.bt.log_training_step(metrics)
        self.wandb.log_training_step(metrics)

    def log_evaluation(
        self,
        input_data: Any,
        output: Any,
        expected: Any,
        scores: dict[str, float],
    ):
        """
        Log an evaluation to Braintrust and its scores to W&B.

        Args:
            input_data: Input data
            output: Model output
            expected: Expected output
            scores: Evaluation scores
        """
        self.bt.log_evaluation(input_data, output, expected, scores)
        self.wandb.log(scores)

    def log_artifact(self, path: str | Path, name: str):
        """
        Log artifact to both platforms.

        Args:
            path: Path to artifact
            name: Artifact name
        """
        self.bt.log_artifact(path, name)
        self.wandb.log_artifact(path, name)

    def finish(self):
        """End tracking on both platforms."""
        bt_summary = self.bt.end_experiment()
        self.wandb.finish()
        logger.info("Unified experiment ended")
        return bt_summary
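

# Minimal offline smoke test (illustrative sketch, not a production entry
# point). With BRAINTRUST_API_KEY and WANDB_API_KEY unset, both trackers fall
# back to offline mode, so running this module directly only exercises the
# local logging paths. The experiment name and hyperparameter values below are
# placeholders.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    tracker = UnifiedExperimentTracker(project_name="mcts-neural-meta-controller")
    tracker.init_experiment(
        name="offline_smoke_test",
        config={"learning_rate": 1e-3, "hidden_size": 128},
        description="Offline smoke test of the unified tracker",
        tags=["smoke-test"],
    )

    # Log a few synthetic training steps through both trackers.
    for demo_step in range(3):
        tracker.log_metrics(
            TrainingMetrics(
                epoch=0,
                step=demo_step,
                train_loss=1.0 / (demo_step + 1),
                learning_rate=1e-3,
            )
        )

    final_summary = tracker.finish()
    logger.info(f"Offline experiment summary: {final_summary}")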