Spaces:
Sleeping
Sleeping
| """ | |
| Educational MCTS demonstration using the production framework. | |
| This demo uses the real MCTSEngine from src.framework.mcts.core to provide | |
| an authentic learning experience while remaining accessible for demonstrations. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from typing import Any | |
| from src.framework.mcts.core import MCTSEngine, MCTSNode, MCTSState | |
| from src.framework.mcts.policies import RolloutPolicy, SelectionPolicy | |
class DemoRolloutPolicy(RolloutPolicy):
    """
    Heuristic rollout policy used by the educational demo.

    A state's score is assembled from three parts:
    - a depth bonus (deeper states score higher, capped)
    - an action-quality bonus (last action mentions a category keyword)
    - a small random perturbation drawn from the supplied RNG
    """

    def __init__(self, category: str, action_templates: dict[str, list[str]]):
        """
        Set up the policy for one query category.

        Args:
            category: Query category whose keyword list drives the quality bonus
            action_templates: Action templates per category (stored on the
                instance; not read by ``evaluate``)
        """
        self.category = category
        self.action_templates = action_templates
        # Keywords whose presence in an action text marks it as a quality
        # step for the given category; "default" is the fallback list.
        self.quality_indicators = {
            "architecture": ["scalability", "consistency", "requirements"],
            "optimization": ["profile", "caching", "parallel"],
            "database": ["patterns", "relationships", "scaling"],
            "distributed": ["circuit", "retry", "bulkhead"],
            "default": ["decompose", "constraints", "trade-offs"],
        }

    async def evaluate(
        self,
        state: MCTSState,
        rng,
        max_depth: int = 10,
    ) -> float:
        """
        Score a state with a cheap heuristic instead of a real rollout.

        Args:
            state: State to score; reads ``features["depth"]`` and
                ``features["last_action"]`` (both optional)
            rng: Random source exposing ``uniform(low, high)``
            max_depth: Accepted for interface compatibility; not used here

        Returns:
            Score clamped to the [0, 1] range
        """
        features = state.features

        # 0.1 per level of depth, capped at 0.3.
        depth_bonus = min(features.get("depth", 0) * 0.1, 0.3)

        # Flat 0.15 bonus when the last action mentions any category keyword.
        action_bonus = 0.0
        last_action = features.get("last_action", "")
        if last_action:
            keywords = self.quality_indicators.get(
                self.category, self.quality_indicators["default"]
            )
            lowered = last_action.lower()
            if any(keyword in lowered for keyword in keywords):
                action_bonus = 0.15

        # Exactly one draw per evaluation so seeded runs stay reproducible.
        noise = rng.uniform(-0.1, 0.1)

        # Base value 0.5 plus the three components, clamped to [0, 1].
        score = 0.5 + depth_bonus + action_bonus + noise
        return max(0.0, min(1.0, score))
class MCTSDemo:
    """
    Educational MCTS demonstration built on the framework's MCTSEngine.

    This class wraps the engine to provide:
    - A simple interface for demos (``search``)
    - Category-based action selection driven by keywords in the query
    - An ASCII rendering of the resulting search tree
    - Repeatable runs (a fixed fallback seed is used when none is given)
    """

    def __init__(self, max_depth: int = 5):
        """
        Initialize MCTS demo.

        Args:
            max_depth: Passed to the engine as ``max_rollout_depth``
        """
        self.max_depth = max_depth
        # Action templates per query category; each list is the candidate
        # set of reasoning steps offered to the search for that category.
        self.action_templates = {
            "architecture": [
                "Consider microservices for scalability",
                "Evaluate monolith for simplicity",
                "Analyze team capabilities",
                "Assess deployment requirements",
                "Review data consistency needs",
            ],
            "optimization": [
                "Profile application hotspots",
                "Implement caching layer",
                "Use parallel processing",
                "Optimize database queries",
                "Reduce memory allocations",
            ],
            "database": [
                "Analyze query patterns",
                "Consider data relationships",
                "Evaluate consistency requirements",
                "Plan for horizontal scaling",
                "Assess read/write ratios",
            ],
            "distributed": [
                "Implement circuit breakers",
                "Add retry mechanisms",
                "Use message queues",
                "Apply bulkhead pattern",
                "Design for eventual consistency",
            ],
            "default": [
                "Decompose the problem",
                "Identify constraints",
                "Evaluate trade-offs",
                "Consider alternatives",
                "Validate assumptions",
            ],
        }

    def _categorize_query(self, query: str) -> str:
        """
        Map a query to the category whose action templates should be used.

        Matching is a case-insensitive substring test; categories are tried
        in a fixed order and the first hit wins.

        Args:
            query: User's input query

        Returns:
            Category name, or "default" when no keyword matches
        """
        query_lower = query.lower()
        # Order matters: earlier rows take precedence over later ones.
        keyword_table = (
            ("architecture", ("architecture", "microservice")),
            ("optimization", ("optim", "performance")),
            ("database", ("database", "sql")),
            ("distributed", ("distribut", "fault")),
        )
        for category, keywords in keyword_table:
            if any(keyword in query_lower for keyword in keywords):
                return category
        return "default"

    def _create_action_generator(self, category: str):
        """
        Create the action generator function for one query category.

        Args:
            category: Query category (unknown names fall back to "default")

        Returns:
            Function mapping a state to the list of actions available in it
        """

        def action_generator(state: MCTSState) -> list[str]:
            """Return the category's actions not yet used in this state."""
            actions = self.action_templates.get(category, self.action_templates["default"])
            used_actions = state.features.get("used_actions", set())
            available = [a for a in actions if a not in used_actions]
            # Once every action has been used, re-offer the first two
            # templates so the generator never returns an empty list.
            if not available:
                return actions[:2]
            return available

        return action_generator

    def _create_state_transition(self, category: str):
        """
        Create the state transition function for one query category.

        Args:
            category: Query category recorded in each produced state

        Returns:
            Function computing the next state from a state and an action
        """

        def state_transition(state: MCTSState, action: str) -> MCTSState:
            """Build the successor state reached by applying ``action``."""
            # Copy before extending so the parent state's containers are
            # never mutated.
            action_history = list(state.features.get("action_history", []))
            action_history.append(action)
            used_actions = set(state.features.get("used_actions", set()))
            used_actions.add(action)
            depth = state.features.get("depth", 0) + 1
            # The state id is the full action path.
            return MCTSState(
                state_id=" -> ".join(action_history),
                features={
                    "action_history": action_history,
                    "used_actions": used_actions,
                    "depth": depth,
                    "last_action": action,
                    "category": category,
                },
            )

        return state_transition

    def _generate_tree_visualization(self, root: MCTSNode, max_nodes: int = 20) -> str:
        """
        Generate an ASCII visualization of the MCTS tree.

        At most five children of the root and three children of every other
        node are drawn. Rendering stops once ``max_nodes`` nodes (root
        included) have been drawn; a single "... (truncated)" marker is
        emitted at the point where the budget ran out.

        Args:
            root: Root node of the tree
            max_nodes: Maximum nodes to display (values below 1 become 1)

        Returns:
            Multi-line string with one line per rendered node
        """
        max_nodes = max(1, max_nodes)
        lines = ["MCTS Tree Visualization", "=" * 50]
        nodes_rendered = 0
        truncated = False

        def shorten(text: str, limit: int) -> str:
            # Only add an ellipsis when characters were actually dropped.
            return f"{text[:limit]}..." if len(text) > limit else text

        def format_node(node: MCTSNode, prefix: str = "", is_last: bool = True) -> list[str]:
            nonlocal nodes_rendered, truncated
            connector = "└── " if is_last else "├── "
            if nodes_rendered >= max_nodes:
                # Emit the marker once; later nodes are skipped silently.
                if truncated:
                    return []
                truncated = True
                return [f"{prefix}{connector}... (truncated)"]
            nodes_rendered += 1

            # Prefer the action that led to the node; fall back to state id.
            if node.action:
                label = shorten(node.action, 25)
            else:
                label = shorten(node.state.state_id, 30)
            result = [f"{prefix}{connector}{label} [V:{node.visits}, Q:{node.value:.3f}]"]

            # Continuation prefix is as wide as the connector so that nested
            # levels line up under their parent.
            child_prefix = prefix + ("    " if is_last else "│   ")
            shown = node.children[:3]
            for i, child in enumerate(shown):
                result.extend(format_node(child, child_prefix, i == len(shown) - 1))
            if len(node.children) > 3:
                result.append(f"{child_prefix}    ... and {len(node.children) - 3} more")
            return result

        root_label = shorten(root.state.state_id, 40)
        lines.append(f"Root: {root_label} [V:{root.visits}, Q:{root.value:.3f}]")
        nodes_rendered += 1  # the root counts against the budget

        top_children = root.children[:5]
        for i, child in enumerate(top_children):
            lines.extend(format_node(child, "", i == len(top_children) - 1))
        if len(root.children) > 5:
            lines.append(f"... and {len(root.children) - 5} more branches")
        return "\n".join(lines)

    async def search(
        self,
        query: str,
        iterations: int = 25,
        exploration_weight: float = 1.414,
        seed: int | None = None,
    ) -> dict[str, Any]:
        """
        Run an MCTS search for the query using the framework engine.

        The query is categorized, a root state/node is created, and the
        engine is run with a category-specific action generator, state
        transition and ``DemoRolloutPolicy``.

        Args:
            query: The input query to analyze
            iterations: Number of MCTS iterations requested from the engine
            exploration_weight: Exploration constant handed to the engine
            seed: Random seed; when None the engine is seeded with 42

        Returns:
            Dictionary with:
            - best_action: Action returned by the engine, or "No action found"
            - best_value: Value of the most-visited root child (0.0 if none)
            - root_visits, total_nodes, max_depth_reached: Tree metrics
            - iterations_completed: The requested iteration count
            - exploration_weight, seed, category: Echo of the run settings
            - top_actions: Up to five root children, most visited first
            - framework_stats: Cache and simulation counters from the engine
            - tree_visualization: ASCII rendering of the search tree
        """
        category = self._categorize_query(query)

        engine = MCTSEngine(
            seed=seed if seed is not None else 42,
            exploration_weight=exploration_weight,
            progressive_widening_k=1.0,  # Moderate expansion
            progressive_widening_alpha=0.5,
            max_parallel_rollouts=4,
            cache_size_limit=10000,
        )

        root_state = MCTSState(
            state_id=f"Query: {query[:50]}",
            features={
                "query": query,
                "category": category,
                "action_history": [],
                "used_actions": set(),
                "depth": 0,
                "last_action": "",
            },
        )
        root = MCTSNode(state=root_state, rng=engine.rng)

        action_generator = self._create_action_generator(category)
        state_transition = self._create_state_transition(category)
        rollout_policy = DemoRolloutPolicy(category, self.action_templates)

        best_action, stats = await engine.search(
            root=root,
            num_iterations=iterations,
            action_generator=action_generator,
            state_transition=state_transition,
            rollout_policy=rollout_policy,
            max_rollout_depth=self.max_depth,
            selection_policy=SelectionPolicy.MAX_VISITS,  # Most robust
        )

        # Most-visited root child, mirroring the MAX_VISITS selection policy.
        best_child = None
        if root.children:
            best_child = max(root.children, key=lambda c: c.visits)

        result = {
            "best_action": best_action or "No action found",
            "best_value": round(best_child.value, 4) if best_child else 0.0,
            "root_visits": root.visits,
            "total_nodes": engine.get_cached_node_count(),
            "max_depth_reached": engine.get_cached_tree_depth(),
            # NOTE: this is the requested count, not a figure reported by
            # the engine.
            "iterations_completed": iterations,
            "exploration_weight": exploration_weight,
            # NOTE: echoes the caller's argument, so it is None when the
            # engine fell back to its default seed of 42.
            "seed": seed,
            "category": category,
            # Top actions sorted by visits
            "top_actions": [
                {
                    "action": child.action,
                    "visits": child.visits,
                    "value": round(child.value, 4),
                    # Despite the key name this is the child's share of
                    # root visits, not a UCB1 score.
                    "ucb1": round(
                        child.visits / root.visits if root.visits > 0 else 0.0, 4
                    ),
                }
                for child in sorted(root.children, key=lambda c: -c.visits)[:5]
            ],
            # Framework statistics
            "framework_stats": {
                "cache_hits": stats.get("cache_hits", 0),
                "cache_misses": stats.get("cache_misses", 0),
                "cache_hit_rate": round(stats.get("cache_hit_rate", 0.0), 4),
                "total_simulations": stats.get("total_simulations", 0),
            },
            # Educational visualization
            "tree_visualization": self._generate_tree_visualization(root),
        }
        return result