Spaces:
Running
Running
| """ | |
| Debug utilities for multi-agent MCTS framework. | |
| Provides: | |
| - MCTS tree visualization (text-based) | |
| - Step-by-step MCTS execution logging when LOG_LEVEL=DEBUG | |
| - UCB score logging at each selection | |
| - State diff tracking between iterations | |
| - Export tree to DOT format for graphviz | |
| """ | |
| import logging | |
| import os | |
| from typing import Any | |
| from .logging import get_logger | |
| class MCTSDebugger: | |
| """ | |
| Comprehensive debugger for MCTS operations. | |
| Provides detailed step-by-step logging, tree visualization, | |
| and state tracking for MCTS execution. | |
| """ | |
| def __init__(self, session_id: str = "default", enabled: bool | None = None): | |
| """ | |
| Initialize MCTS debugger. | |
| Args: | |
| session_id: Unique identifier for debug session | |
| enabled: Enable debugging (defaults to LOG_LEVEL=DEBUG) | |
| """ | |
| self.session_id = session_id | |
| self.logger = get_logger("observability.debug") | |
| # Auto-enable if LOG_LEVEL is DEBUG | |
| if enabled is None: | |
| log_level = os.environ.get("LOG_LEVEL", "INFO").upper() | |
| self.enabled = log_level == "DEBUG" | |
| else: | |
| self.enabled = enabled | |
| # State tracking | |
| self._iteration_count = 0 | |
| self._state_history: list[dict[str, Any]] = [] | |
| self._selection_history: list[dict[str, Any]] = [] | |
| self._ucb_history: list[dict[str, float]] = [] | |
| def log_iteration_start(self, iteration: int) -> None: | |
| """Log the start of an MCTS iteration.""" | |
| if not self.enabled: | |
| return | |
| self._iteration_count = iteration | |
| self.logger.debug( | |
| f"=== MCTS Iteration {iteration} START ===", | |
| extra={ | |
| "debug_event": "iteration_start", | |
| "mcts_iteration": iteration, | |
| "session_id": self.session_id, | |
| }, | |
| ) | |
| def log_selection( | |
| self, | |
| node_id: str, | |
| ucb_score: float, | |
| visits: int, | |
| value: float, | |
| depth: int, | |
| children_count: int, | |
| is_selected: bool = False, | |
| ) -> None: | |
| """Log UCB score and selection decision for a node.""" | |
| if not self.enabled: | |
| return | |
| selection_data = { | |
| "node_id": node_id, | |
| "ucb_score": round(ucb_score, 6), | |
| "visits": visits, | |
| "value": round(value, 6), | |
| "avg_value": round(value / max(visits, 1), 6), | |
| "depth": depth, | |
| "children_count": children_count, | |
| "is_selected": is_selected, | |
| } | |
| self._selection_history.append(selection_data) | |
| log_msg = f"Selection: node={node_id} UCB={ucb_score:.4f} visits={visits} value={value:.4f} depth={depth}" | |
| if is_selected: | |
| log_msg += " [SELECTED]" | |
| self.logger.debug( | |
| log_msg, | |
| extra={ | |
| "debug_event": "mcts_selection", | |
| "selection": selection_data, | |
| "session_id": self.session_id, | |
| "mcts_iteration": self._iteration_count, | |
| }, | |
| ) | |
| def log_ucb_comparison( | |
| self, | |
| parent_id: str, | |
| children_ucb: dict[str, float], | |
| selected_child: str, | |
| ) -> None: | |
| """Log UCB score comparison for all children of a node.""" | |
| if not self.enabled: | |
| return | |
| self._ucb_history.append(children_ucb) | |
| ucb_summary = ", ".join( | |
| [ | |
| f"{cid}={score:.4f}{'*' if cid == selected_child else ''}" | |
| for cid, score in sorted(children_ucb.items(), key=lambda x: x[1], reverse=True) | |
| ] | |
| ) | |
| self.logger.debug( | |
| f"UCB Comparison at {parent_id}: {ucb_summary}", | |
| extra={ | |
| "debug_event": "ucb_comparison", | |
| "parent_id": parent_id, | |
| "children_ucb": {k: round(v, 6) for k, v in children_ucb.items()}, | |
| "selected_child": selected_child, | |
| "session_id": self.session_id, | |
| "mcts_iteration": self._iteration_count, | |
| }, | |
| ) | |
| def log_expansion( | |
| self, | |
| parent_id: str, | |
| action: str, | |
| new_node_id: str, | |
| available_actions: list[str], | |
| ) -> None: | |
| """Log node expansion details.""" | |
| if not self.enabled: | |
| return | |
| self.logger.debug( | |
| f"Expansion: parent={parent_id} action={action} new_node={new_node_id} " | |
| f"available={len(available_actions)} actions", | |
| extra={ | |
| "debug_event": "mcts_expansion", | |
| "parent_id": parent_id, | |
| "action": action, | |
| "new_node_id": new_node_id, | |
| "available_actions": available_actions, | |
| "session_id": self.session_id, | |
| "mcts_iteration": self._iteration_count, | |
| }, | |
| ) | |
| def log_simulation( | |
| self, | |
| node_id: str, | |
| simulation_result: float, | |
| simulation_details: dict[str, Any] | None = None, | |
| ) -> None: | |
| """Log simulation/rollout results.""" | |
| if not self.enabled: | |
| return | |
| self.logger.debug( | |
| f"Simulation: node={node_id} result={simulation_result:.4f}", | |
| extra={ | |
| "debug_event": "mcts_simulation", | |
| "node_id": node_id, | |
| "simulation_result": round(simulation_result, 6), | |
| "simulation_details": simulation_details or {}, | |
| "session_id": self.session_id, | |
| "mcts_iteration": self._iteration_count, | |
| }, | |
| ) | |
| def log_backpropagation( | |
| self, | |
| path: list[str], | |
| value: float, | |
| updates: list[dict[str, Any]], | |
| ) -> None: | |
| """Log backpropagation path and value updates.""" | |
| if not self.enabled: | |
| return | |
| self.logger.debug( | |
| f"Backprop: path={' -> '.join(path)} value={value:.4f}", | |
| extra={ | |
| "debug_event": "mcts_backprop", | |
| "path": path, | |
| "value": round(value, 6), | |
| "updates": updates, | |
| "session_id": self.session_id, | |
| "mcts_iteration": self._iteration_count, | |
| }, | |
| ) | |
| def log_iteration_end( | |
| self, | |
| iteration: int, | |
| best_action: str, | |
| best_ucb: float, | |
| tree_size: int, | |
| ) -> None: | |
| """Log the end of an MCTS iteration.""" | |
| if not self.enabled: | |
| return | |
| self.logger.debug( | |
| f"=== MCTS Iteration {iteration} END === " | |
| f"best_action={best_action} UCB={best_ucb:.4f} tree_size={tree_size}", | |
| extra={ | |
| "debug_event": "iteration_end", | |
| "mcts_iteration": iteration, | |
| "best_action": best_action, | |
| "best_ucb": round(best_ucb, 6), | |
| "tree_size": tree_size, | |
| "session_id": self.session_id, | |
| }, | |
| ) | |
| def log_state_diff( | |
| self, | |
| old_state: dict[str, Any], | |
| new_state: dict[str, Any], | |
| description: str = "State change", | |
| ) -> None: | |
| """Log differences between two states.""" | |
| if not self.enabled: | |
| return | |
| diff = self._compute_state_diff(old_state, new_state) | |
| if diff: | |
| self._state_history.append( | |
| { | |
| "iteration": self._iteration_count, | |
| "description": description, | |
| "diff": diff, | |
| } | |
| ) | |
| self.logger.debug( | |
| f"State diff: {description}", | |
| extra={ | |
| "debug_event": "state_diff", | |
| "description": description, | |
| "diff": diff, | |
| "session_id": self.session_id, | |
| "mcts_iteration": self._iteration_count, | |
| }, | |
| ) | |
| def _compute_state_diff( | |
| self, | |
| old: dict[str, Any], | |
| new: dict[str, Any], | |
| prefix: str = "", | |
| ) -> dict[str, Any]: | |
| """Compute differences between two dictionaries.""" | |
| diff = {} | |
| all_keys = set(old.keys()) | set(new.keys()) | |
| for key in all_keys: | |
| full_key = f"{prefix}.{key}" if prefix else key | |
| if key not in old: | |
| diff[full_key] = {"added": new[key]} | |
| elif key not in new: | |
| diff[full_key] = {"removed": old[key]} | |
| elif old[key] != new[key]: | |
| if isinstance(old[key], dict) and isinstance(new[key], dict): | |
| nested_diff = self._compute_state_diff(old[key], new[key], full_key) | |
| diff.update(nested_diff) | |
| else: | |
| diff[full_key] = {"old": old[key], "new": new[key]} | |
| return diff | |
| def get_debug_summary(self) -> dict[str, Any]: | |
| """Get summary of debug information collected.""" | |
| return { | |
| "session_id": self.session_id, | |
| "total_iterations": self._iteration_count, | |
| "selection_history_count": len(self._selection_history), | |
| "state_changes_count": len(self._state_history), | |
| "ucb_comparisons_count": len(self._ucb_history), | |
| } | |
def visualize_mcts_tree(
    root_node: Any,
    max_depth: int = 10,
    max_children: int = 5,
    show_ucb: bool = True,
    indent: str = " ",
) -> str:
    """
    Generate text-based visualization of MCTS tree.

    Nodes are read duck-typed through ``visits``, ``value``, ``action``,
    ``state_id``, ``children`` and an optional ``ucb1()`` method; missing
    attributes fall back to defaults. Children are listed most-visited first
    and attached to their parent with ``├──`` / ``└──`` branch glyphs.

    Args:
        root_node: Root MCTSNode
        max_depth: Maximum depth to visualize; deeper nodes are replaced by a
            "... (max depth reached)" marker.
        max_children: Maximum children to show per node; the remainder is
            summarized in a "... and N more children" line.
        show_ucb: Show UCB scores (for visited nodes exposing ``ucb1()``).
        indent: Indentation string inserted before each level's branch glyph.

    Returns:
        Text representation of the tree
    """
    lines = ["MCTS Tree Visualization", "=" * 40]
    # Continuation strings are as wide as the 4-char connectors so columns align.
    branch, last_branch = "├── ", "└── "
    pipe, blank = "│   ", "    "
    limit = max(max_children, 0)

    def describe(node: Any) -> str:
        """Build the one-line summary for a single node."""
        visits = getattr(node, "visits", 0)
        value = getattr(node, "value", 0.0)
        action = getattr(node, "action", "root")
        state_id = getattr(node, "state_id", "unknown")
        avg_value = value / max(visits, 1)

        node_info = f"[{state_id}] action={action} visits={visits} value={value:.3f} avg={avg_value:.3f}"
        if show_ucb and hasattr(node, "ucb1") and visits > 0:
            try:
                ucb = node.ucb1()
                if ucb != float("inf"):
                    node_info += f" UCB={ucb:.3f}"
            except Exception:
                # Best effort: if ucb1() fails (or returns something
                # unformattable) the UCB is simply omitted from the line.
                pass
        return node_info

    def render_node(node: Any, depth: int, line_prefix: str, child_prefix: str) -> None:
        """Append ``node`` and its descendants to ``lines``.

        ``line_prefix`` precedes this node's own line (it already contains the
        branch glyph); ``child_prefix`` is the continuation its children build on.
        """
        if depth > max_depth:
            lines.append(f"{line_prefix}... (max depth reached)")
            return

        lines.append(f"{line_prefix}{describe(node)}")

        children = getattr(node, "children", None) or []
        if not children:
            return

        # Sort by visits (most visited first)
        sorted_children = sorted(children, key=lambda c: getattr(c, "visits", 0), reverse=True)
        shown = sorted_children[:limit]

        for i, child in enumerate(shown):
            is_last = i == len(shown) - 1
            render_node(
                child,
                depth + 1,
                child_prefix + indent + (last_branch if is_last else branch),
                child_prefix + indent + (blank if is_last else pipe),
            )

        hidden = len(children) - len(shown)
        if hidden > 0:
            lines.append(f"{child_prefix}{indent}... and {hidden} more children")

    render_node(root_node, 0, "", "")
    lines.append("=" * 40)
    return "\n".join(lines)
| def export_tree_to_dot( | |
| root_node: Any, | |
| filename: str = "mcts_tree.dot", | |
| max_depth: int = 10, | |
| include_ucb: bool = True, | |
| ) -> str: | |
| """ | |
| Export MCTS tree to DOT format for graphviz visualization. | |
| Args: | |
| root_node: Root MCTSNode | |
| filename: Output filename (optional) | |
| max_depth: Maximum depth to export | |
| include_ucb: Include UCB scores in labels | |
| Returns: | |
| DOT format string | |
| """ | |
| lines = [ | |
| "digraph MCTSTree {", | |
| ' graph [rankdir=TB, label="MCTS Tree", fontsize=16];', | |
| " node [shape=box, style=filled, fillcolor=lightblue];", | |
| " edge [fontsize=10];", | |
| "", | |
| ] | |
| node_counter = [0] # Use list for mutable counter in closure | |
| def add_node(node: Any, depth: int = 0, parent_dot_id: str | None = None) -> None: | |
| if depth > max_depth: | |
| return | |
| # Generate unique DOT ID | |
| dot_id = f"node_{node_counter[0]}" | |
| node_counter[0] += 1 | |
| # Node attributes | |
| visits = getattr(node, "visits", 0) | |
| value = getattr(node, "value", 0.0) | |
| action = getattr(node, "action", "root") | |
| state_id = getattr(node, "state_id", "unknown") | |
| avg_value = value / max(visits, 1) | |
| # Build label | |
| label_parts = [ | |
| f"ID: {state_id}", | |
| f"Action: {action}", | |
| f"Visits: {visits}", | |
| f"Value: {value:.3f}", | |
| f"Avg: {avg_value:.3f}", | |
| ] | |
| if include_ucb and hasattr(node, "ucb1") and visits > 0: | |
| try: | |
| ucb = node.ucb1() | |
| if ucb != float("inf"): | |
| label_parts.append(f"UCB: {ucb:.3f}") | |
| except Exception: | |
| pass | |
| label = "\\n".join(label_parts) | |
| # Color based on value | |
| if avg_value >= 0.7: | |
| color = "lightgreen" | |
| elif avg_value >= 0.4: | |
| color = "lightyellow" | |
| else: | |
| color = "lightcoral" | |
| lines.append(f' {dot_id} [label="{label}", fillcolor={color}];') | |
| # Edge from parent | |
| if parent_dot_id: | |
| lines.append(f' {parent_dot_id} -> {dot_id} [label="{action}"];') | |
| # Process children | |
| children = getattr(node, "children", []) | |
| for child in children: | |
| add_node(child, depth + 1, dot_id) | |
| add_node(root_node) | |
| lines.append("}") | |
| dot_content = "\n".join(lines) | |
| # Write to file if filename provided | |
| if filename: | |
| with open(filename, "w") as f: | |
| f.write(dot_content) | |
| return dot_content | |
def print_debug_banner(message: str, char: str = "=", width: int = 60) -> None:
    """
    Emit a three-line banner through the ``observability.debug`` logger.

    The banner is logged at DEBUG level rather than printed directly. It
    consists of a rule made of ``char`` repeated ``width`` times, then
    ``message`` centred in a ``width``-wide field, then the rule again.

    Args:
        message: Text to centre between the two rules.
        char: String repeated ``width`` times to form each rule.
        width: Repetition count for the rule and field width for centring.
    """
    banner_logger = get_logger("observability.debug")
    rule = width * char

    banner_logger.debug(rule)
    centred_line = f"{message:^{width}}"
    banner_logger.debug(centred_line)
    banner_logger.debug(rule)
def log_agent_state_snapshot(
    agent_name: str,
    state: dict[str, Any],
    include_keys: list[str] | None = None,
) -> None:
    """
    Log a snapshot of agent state for debugging.

    Args:
        agent_name: Name of the agent
        state: Current state dictionary
        include_keys: Specific keys to include. ``None`` logs the whole state;
            an explicit (even empty) list logs only those keys. Keys missing
            from ``state`` are skipped.
    """
    logger = get_logger("observability.debug")

    # Compare against None explicitly: an empty list means "no keys",
    # not "all keys".
    if include_keys is None:
        filtered_state = state
    else:
        filtered_state = {k: state[k] for k in include_keys if k in state}

    logger.debug(
        f"Agent {agent_name} state snapshot",
        extra={
            "debug_event": "agent_state_snapshot",
            "agent_name": agent_name,
            "state": filtered_state,
        },
    )
def enable_verbose_debugging() -> None:
    """
    Turn on verbose (DEBUG) logging process-wide.

    Sets ``LOG_LEVEL=DEBUG`` in the environment, so a debugger created later
    with ``enabled=None`` enables itself. Also lowers the root logger and
    every handler currently attached to it to DEBUG, then announces the
    change at INFO level.
    """
    target_level = logging.DEBUG
    os.environ["LOG_LEVEL"] = "DEBUG"

    # Reconfigure root logger and all of its handlers
    root = logging.getLogger()
    root.setLevel(target_level)
    for attached_handler in root.handlers:
        attached_handler.setLevel(target_level)

    get_logger("observability.debug").info("Verbose debugging ENABLED")
def disable_verbose_debugging() -> None:
    """
    Turn verbose logging back off by dropping to INFO process-wide.

    Sets ``LOG_LEVEL=INFO`` in the environment, so a debugger created later
    with ``enabled=None`` stays disabled. Also sets the root logger and every
    handler currently attached to it to INFO, then announces the change at
    INFO level.
    """
    target_level = logging.INFO
    os.environ["LOG_LEVEL"] = "INFO"

    root = logging.getLogger()
    root.setLevel(target_level)
    for attached_handler in root.handlers:
        attached_handler.setLevel(target_level)

    get_logger("observability.debug").info("Verbose debugging DISABLED")