| """ | |
| Unified Training Orchestrator for LangGraph Multi-Agent MCTS with DeepMind-Style Learning. | |
| Coordinates: | |
| - HRM Agent | |
| - TRM Agent | |
| - Neural MCTS | |
| - Policy-Value Network | |
| - Self-play data generation | |
| - Training loops | |
| - Evaluation | |
| - Checkpointing | |
| """ | |
| import time | |
| from collections.abc import Callable | |
| from pathlib import Path | |
| from typing import Any | |
| import torch | |
| import torch.nn as nn | |
| from torch.cuda.amp import GradScaler, autocast | |
| from ..agents.hrm_agent import HRMLoss, create_hrm_agent | |
| from ..agents.trm_agent import TRMLoss, create_trm_agent | |
| from ..framework.mcts.neural_mcts import GameState, NeuralMCTS, SelfPlayCollector | |
| from ..models.policy_value_net import ( | |
| AlphaZeroLoss, | |
| create_policy_value_network, | |
| ) | |
| from .performance_monitor import PerformanceMonitor, TimingContext | |
| from .replay_buffer import Experience, PrioritizedReplayBuffer, collate_experiences | |
| from .system_config import SystemConfig | |


class UnifiedTrainingOrchestrator:
    """
    Complete training pipeline integrating all framework components.

    This orchestrator manages:
    1. Self-play data generation using MCTS
    2. Neural network training (policy-value)
    3. HRM agent training
    4. TRM agent training
    5. Evaluation and checkpointing
    6. Performance monitoring
    """

    def __init__(
        self,
        config: SystemConfig,
        initial_state_fn: Callable[[], GameState],
        board_size: int = 19,
    ):
        """
        Initialize training orchestrator.

        Args:
            config: System configuration
            initial_state_fn: Function that returns the initial game state
            board_size: Board/grid size for spatial games
        """
        self.config = config
        self.initial_state_fn = initial_state_fn
        self.board_size = board_size

        # Setup device
        self.device = config.device
        torch.manual_seed(config.seed)

        # Initialize performance monitor
        self.monitor = PerformanceMonitor(
            window_size=100,
            enable_gpu_monitoring=(self.device != "cpu"),
        )

        # Initialize components
        self._initialize_components()

        # Training state
        self.current_iteration = 0
        self.best_win_rate = 0.0
        self.best_model_path = None

        # Setup paths
        self._setup_paths()

        # Setup experiment tracking
        if config.use_wandb:
            self._setup_wandb()

    def _initialize_components(self):
        """Initialize all framework components."""
        print("Initializing components...")

        # Policy-Value Network
        self.policy_value_net = create_policy_value_network(
            config=self.config.neural_net,
            board_size=self.board_size,
            device=self.device,
        )
        print(f" ✓ Policy-Value Network: {self.policy_value_net.get_parameter_count():,} parameters")

        # HRM Agent
        self.hrm_agent = create_hrm_agent(self.config.hrm, self.device)
        print(f" ✓ HRM Agent: {self.hrm_agent.get_parameter_count():,} parameters")

        # TRM Agent
        self.trm_agent = create_trm_agent(
            self.config.trm, output_dim=self.config.neural_net.action_size, device=self.device
        )
        print(f" ✓ TRM Agent: {self.trm_agent.get_parameter_count():,} parameters")

        # Neural MCTS
        self.mcts = NeuralMCTS(
            policy_value_network=self.policy_value_net,
            config=self.config.mcts,
            device=self.device,
        )
        print(" ✓ Neural MCTS initialized")

        # Self-play collector
        self.self_play_collector = SelfPlayCollector(mcts=self.mcts, config=self.config.mcts)

        # Optimizers
        self._setup_optimizers()

        # Loss functions
        self.pv_loss_fn = AlphaZeroLoss(value_loss_weight=1.0)
        self.hrm_loss_fn = HRMLoss(ponder_weight=0.01)
        self.trm_loss_fn = TRMLoss(
            task_loss_fn=nn.MSELoss(),
            supervision_weight_decay=self.config.trm.supervision_weight_decay,
        )

        # Replay buffer
        self.replay_buffer = PrioritizedReplayBuffer(
            capacity=self.config.training.buffer_size,
            alpha=0.6,
            beta_start=0.4,
            beta_frames=self.config.training.games_per_iteration * 10,
        )

        # Mixed precision scaler
        self.scaler = GradScaler() if self.config.use_mixed_precision else None

    def _setup_optimizers(self):
        """Setup optimizers and learning rate schedulers."""
        # Policy-Value optimizer
        self.pv_optimizer = torch.optim.SGD(
            self.policy_value_net.parameters(),
            lr=self.config.training.learning_rate,
            momentum=self.config.training.momentum,
            weight_decay=self.config.training.weight_decay,
        )

        # HRM optimizer
        self.hrm_optimizer = torch.optim.Adam(self.hrm_agent.parameters(), lr=1e-3)

        # TRM optimizer
        self.trm_optimizer = torch.optim.Adam(self.trm_agent.parameters(), lr=1e-3)

        # Learning rate scheduler for policy-value network
        if self.config.training.lr_schedule == "cosine":
            self.pv_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.pv_optimizer, T_max=100)
        elif self.config.training.lr_schedule == "step":
            self.pv_scheduler = torch.optim.lr_scheduler.StepLR(
                self.pv_optimizer,
                step_size=self.config.training.lr_decay_steps,
                gamma=self.config.training.lr_decay_gamma,
            )
        else:
            self.pv_scheduler = None

    def _setup_paths(self):
        """Setup directory paths."""
        self.checkpoint_dir = Path(self.config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        self.data_dir = Path(self.config.data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.log_dir = Path(self.config.log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

    def _setup_wandb(self):
        """Setup Weights & Biases experiment tracking."""
        try:
            import wandb

            wandb.init(
                project=self.config.wandb_project,
                entity=self.config.wandb_entity,
                config=self.config.to_dict(),
                name=f"run_{time.strftime('%Y%m%d_%H%M%S')}",
            )
            print(" ✓ Weights & Biases initialized")
        except ImportError:
            print(" ⚠️ wandb not installed, skipping")
            self.config.use_wandb = False

    async def train_iteration(self, iteration: int) -> dict[str, Any]:
        """
        Execute a single training iteration.

        Args:
            iteration: Current iteration number

        Returns:
            Dictionary of metrics
        """
        print(f"\n{'=' * 80}")
        print(f"Training Iteration {iteration}")
        print(f"{'=' * 80}")

        metrics = {}

        # Phase 1: Self-play data generation
        print("\n[1/5] Generating self-play data...")
        with TimingContext(self.monitor, "self_play_generation"):
            game_data = await self._generate_self_play_data()
        metrics["games_generated"] = len(game_data)
        print(f" Generated {len(game_data)} training examples")

        # Phase 2: Policy-Value network training
        print("\n[2/5] Training Policy-Value Network...")
        with TimingContext(self.monitor, "pv_training"):
            pv_metrics = await self._train_policy_value_network()
        metrics.update(pv_metrics)

        # Phase 3: HRM agent training (optional, if using HRM)
        if hasattr(self, "hrm_agent"):
            print("\n[3/5] Training HRM Agent...")
            with TimingContext(self.monitor, "hrm_training"):
                hrm_metrics = await self._train_hrm_agent()
            metrics.update(hrm_metrics)

        # Phase 4: TRM agent training (optional, if using TRM)
        if hasattr(self, "trm_agent"):
            print("\n[4/5] Training TRM Agent...")
            with TimingContext(self.monitor, "trm_training"):
                trm_metrics = await self._train_trm_agent()
            metrics.update(trm_metrics)

        # Phase 5: Evaluation
        print("\n[5/5] Evaluation...")
        if iteration % self.config.training.checkpoint_interval == 0:
            eval_metrics = await self._evaluate()
            metrics.update(eval_metrics)

            # Save checkpoint if improved
            if eval_metrics.get("win_rate", 0) > self.best_win_rate:
                self.best_win_rate = eval_metrics["win_rate"]
                self._save_checkpoint(iteration, metrics, is_best=True)
                print(f" ✓ New best model! Win rate: {self.best_win_rate:.2%}")

        # Log metrics
        self._log_metrics(iteration, metrics)

        # Performance check
        self.monitor.alert_if_slow()

        return metrics

    async def _generate_self_play_data(self) -> list[Experience]:
        """Generate training data from self-play games."""
        num_games = self.config.training.games_per_iteration

        # In production, this would use parallel actors.
        # For simplicity, we do sequential self-play here.
        all_examples = []
        for game_idx in range(num_games):
            examples = await self.self_play_collector.play_game(
                initial_state=self.initial_state_fn(),
                temperature_threshold=self.config.mcts.temperature_threshold,
            )

            # Convert to Experience objects
            for ex in examples:
                all_examples.append(Experience(state=ex.state, policy=ex.policy_target, value=ex.value_target))

            if (game_idx + 1) % 5 == 0:
                print(f" Generated {game_idx + 1}/{num_games} games...")

        # Add to replay buffer
        self.replay_buffer.add_batch(all_examples)

        return all_examples

    async def _train_policy_value_network(self) -> dict[str, float]:
        """Train policy-value network on replay buffer data."""
        if not self.replay_buffer.is_ready(self.config.training.batch_size):
            print(" Replay buffer not ready, skipping...")
            return {"policy_loss": 0.0, "value_loss": 0.0}

        self.policy_value_net.train()

        total_policy_loss = 0.0
        total_value_loss = 0.0
        num_batches = 10  # Train for 10 batches per iteration

        for _ in range(num_batches):
            # Sample batch
            experiences, indices, weights = self.replay_buffer.sample(self.config.training.batch_size)
            states, policies, values = collate_experiences(experiences)

            states = states.to(self.device)
            policies = policies.to(self.device)
            values = values.to(self.device)
            weights = torch.from_numpy(weights).to(self.device)

            # Forward pass
            if self.config.use_mixed_precision and self.scaler:
                with autocast():
                    policy_logits, value_pred = self.policy_value_net(states)
                    loss, loss_dict = self.pv_loss_fn(policy_logits, value_pred, policies, values)

                    # Apply importance sampling weights
                    loss = (loss * weights).mean()

                # Backward pass with mixed precision
                self.pv_optimizer.zero_grad()
                self.scaler.scale(loss).backward()
                self.scaler.step(self.pv_optimizer)
                self.scaler.update()
            else:
                policy_logits, value_pred = self.policy_value_net(states)
                loss, loss_dict = self.pv_loss_fn(policy_logits, value_pred, policies, values)
                loss = (loss * weights).mean()

                self.pv_optimizer.zero_grad()
                loss.backward()
                self.pv_optimizer.step()

            # Update priorities in replay buffer
            with torch.no_grad():
                td_errors = torch.abs(value_pred.squeeze() - values)
                self.replay_buffer.update_priorities(indices, td_errors.cpu().numpy())

            total_policy_loss += loss_dict["policy"]
            total_value_loss += loss_dict["value"]

            # Log losses
            self.monitor.log_loss(loss_dict["policy"], loss_dict["value"], loss_dict["total"])

        # Step learning rate scheduler
        if self.pv_scheduler:
            self.pv_scheduler.step()

        avg_policy_loss = total_policy_loss / num_batches
        avg_value_loss = total_value_loss / num_batches

        print(f" Policy Loss: {avg_policy_loss:.4f}, Value Loss: {avg_value_loss:.4f}")

        return {"policy_loss": avg_policy_loss, "value_loss": avg_value_loss}

    async def _train_hrm_agent(self) -> dict[str, float]:
        """Train HRM agent (placeholder for domain-specific implementation)."""
        # This would require domain-specific data and tasks.
        # For now, return dummy metrics.
        return {"hrm_halt_step": 5.0, "hrm_ponder_cost": 0.1}

    async def _train_trm_agent(self) -> dict[str, float]:
        """Train TRM agent (placeholder for domain-specific implementation)."""
        # This would require domain-specific data and tasks.
        # For now, return dummy metrics.
        return {"trm_convergence_step": 8.0, "trm_final_residual": 0.01}

    async def _evaluate(self) -> dict[str, float]:
        """Evaluate current model against baseline."""
        # Simplified evaluation: play games against previous best.
        # In production, this would be more sophisticated.
        win_rate = 0.55  # Placeholder
        return {
            "win_rate": win_rate,
            "eval_games": self.config.training.evaluation_games,
        }

    def _save_checkpoint(self, iteration: int, metrics: dict, is_best: bool = False):
        """Save model checkpoint."""
        checkpoint = {
            "iteration": iteration,
            "policy_value_net": self.policy_value_net.state_dict(),
            "hrm_agent": self.hrm_agent.state_dict(),
            "trm_agent": self.trm_agent.state_dict(),
            "pv_optimizer": self.pv_optimizer.state_dict(),
            "hrm_optimizer": self.hrm_optimizer.state_dict(),
            "trm_optimizer": self.trm_optimizer.state_dict(),
            "config": self.config.to_dict(),
            "metrics": metrics,
            "best_win_rate": self.best_win_rate,
        }

        # Save regular checkpoint
        path = self.checkpoint_dir / f"checkpoint_iter_{iteration}.pt"
        torch.save(checkpoint, path)
        print(f" ✓ Checkpoint saved: {path}")

        # Save best model
        if is_best:
            best_path = self.checkpoint_dir / "best_model.pt"
            torch.save(checkpoint, best_path)
            self.best_model_path = best_path
            print(f" ✓ Best model saved: {best_path}")

    def _log_metrics(self, iteration: int, metrics: dict):
        """Log metrics to console and tracking systems."""
        print(f"\n[Metrics Summary - Iteration {iteration}]")
        for key, value in metrics.items():
            if isinstance(value, float):
                print(f" {key}: {value:.4f}")
            else:
                print(f" {key}: {value}")

        # Log to wandb
        if self.config.use_wandb:
            try:
                import wandb

                wandb_metrics = self.monitor.export_to_wandb(iteration)
                wandb_metrics.update(metrics)
                wandb.log(wandb_metrics, step=iteration)
            except Exception as e:
                print(f" ⚠️ Failed to log to wandb: {e}")

    async def train(self, num_iterations: int):
        """
        Run the complete training loop.

        Args:
            num_iterations: Number of training iterations
        """
        print("\n" + "=" * 80)
        print("Starting Training")
        print("=" * 80)
        print(f"Total iterations: {num_iterations}")
        print(f"Device: {self.device}")
        print(f"Mixed precision: {self.config.use_mixed_precision}")

        start_time = time.time()

        for iteration in range(1, num_iterations + 1):
            self.current_iteration = iteration
            try:
                _ = await self.train_iteration(iteration)

                # Check early stopping
                if self._should_early_stop(iteration):
                    print("\n⚠️ Early stopping triggered")
                    break
            except KeyboardInterrupt:
                print("\n⚠️ Training interrupted by user")
                break
            except Exception as e:
                print(f"\n❌ Error in iteration {iteration}: {e}")
                import traceback

                traceback.print_exc()
                break

        elapsed = time.time() - start_time
        print(f"\n{'=' * 80}")
        print(f"Training completed in {elapsed / 3600:.2f} hours")
        print(f"Best win rate: {self.best_win_rate:.2%}")
        print(f"{'=' * 80}\n")

        # Print final performance summary
        self.monitor.print_summary()

    def _should_early_stop(self, iteration: int) -> bool:
        """Check early stopping criteria."""
        # Placeholder: implement actual early stopping logic
        _ = iteration  # noqa: F841
        return False

    def load_checkpoint(self, path: str):
        """Load checkpoint from file."""
        checkpoint = torch.load(path, map_location=self.device, weights_only=True)

        self.policy_value_net.load_state_dict(checkpoint["policy_value_net"])
        self.hrm_agent.load_state_dict(checkpoint["hrm_agent"])
        self.trm_agent.load_state_dict(checkpoint["trm_agent"])

        self.pv_optimizer.load_state_dict(checkpoint["pv_optimizer"])
        self.hrm_optimizer.load_state_dict(checkpoint["hrm_optimizer"])
        self.trm_optimizer.load_state_dict(checkpoint["trm_optimizer"])

        self.current_iteration = checkpoint["iteration"]
        self.best_win_rate = checkpoint.get("best_win_rate", 0.0)

        print(f"✓ Loaded checkpoint from iteration {self.current_iteration}")