# langgraph-mcts-demo/demo_src/wandb_tracker.py
# ianshank
# feat: add personality output and bug fixes
# 40ee6b4
"""
Weights & Biases integration for experiment tracking.
"""
import html
import json
import os
import tempfile
from datetime import datetime
from typing import Any

try:
    import wandb

    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    wandb = None
class WandBTracker:
"""Weights & Biases experiment tracker for multi-agent MCTS demo."""
def __init__(self, project_name: str = "langgraph-mcts-demo", entity: str | None = None, enabled: bool = True):
"""Initialize W&B tracker.
Args:
project_name: W&B project name
entity: W&B entity (username or team)
enabled: Whether tracking is enabled
"""
self.project_name = project_name
self.entity = entity
self.enabled = enabled and WANDB_AVAILABLE
self.run = None
self.run_id = None
def is_available(self) -> bool:
"""Check if W&B is available."""
return WANDB_AVAILABLE
def init_run(
self, run_name: str | None = None, config: dict[str, Any] | None = None, tags: list[str] | None = None
) -> bool:
"""Initialize a new W&B run.
Args:
run_name: Optional name for the run
config: Configuration dictionary to log
tags: Tags for the run
Returns:
True if run initialized successfully, False otherwise
"""
if not self.enabled:
return False
try:
# Generate run name if not provided
if run_name is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_name = f"mcts_query_{timestamp}"
# Default tags
if tags is None:
tags = ["demo", "multi-agent", "mcts"]
# Initialize run
self.run = wandb.init(
project=self.project_name,
entity=self.entity,
name=run_name,
config=config or {},
tags=tags,
reinit=True,
)
self.run_id = self.run.id
return True
except Exception as e:
print(f"W&B init error: {e}")
self.enabled = False
return False
def log_query_config(self, config: dict[str, Any]):
"""Log query configuration.
Args:
config: Configuration dictionary with agent settings, MCTS params, etc.
"""
if not self.enabled or not self.run:
return
try:
wandb.config.update(config)
except Exception as e:
print(f"W&B config log error: {e}")
def log_agent_result(
self,
agent_name: str,
response: str,
confidence: float,
execution_time_ms: float,
reasoning_steps: list[str] | None = None,
):
"""Log individual agent results.
Args:
agent_name: Name of the agent (HRM, TRM, MCTS)
response: Agent's response text
confidence: Confidence score (0-1)
execution_time_ms: Execution time in milliseconds
reasoning_steps: Optional list of reasoning steps
"""
if not self.enabled or not self.run:
return
try:
metrics = {
f"{agent_name}/confidence": confidence,
f"{agent_name}/execution_time_ms": execution_time_ms,
f"{agent_name}/response_length": len(response),
}
if reasoning_steps:
metrics[f"{agent_name}/num_reasoning_steps"] = len(reasoning_steps)
wandb.log(metrics)
# Log response as text
wandb.log({f"{agent_name}/response": wandb.Html(f"<pre>{response}</pre>")})
except Exception as e:
print(f"W&B agent result log error: {e}")
def log_mcts_result(self, mcts_result: dict[str, Any]):
"""Log MCTS-specific metrics.
Args:
mcts_result: Dictionary containing MCTS search results
"""
if not self.enabled or not self.run:
return
try:
# Extract key metrics
metrics = {
"mcts/best_value": mcts_result.get("best_value", 0),
"mcts/root_visits": mcts_result.get("root_visits", 0),
"mcts/total_nodes": mcts_result.get("total_nodes", 0),
"mcts/max_depth": mcts_result.get("max_depth_reached", 0),
"mcts/iterations": mcts_result.get("iterations_completed", 0),
"mcts/exploration_weight": mcts_result.get("exploration_weight", 1.414),
}
wandb.log(metrics)
# Log top actions as table
if "top_actions" in mcts_result:
top_actions_data = []
for action in mcts_result["top_actions"]:
top_actions_data.append(
[
action.get("action", ""),
action.get("visits", 0),
action.get("value", 0),
action.get("ucb1", 0),
]
)
if top_actions_data:
table = wandb.Table(data=top_actions_data, columns=["Action", "Visits", "Value", "UCB1"])
wandb.log({"mcts/top_actions_table": table})
# Log tree visualization as text artifact
if "tree_visualization" in mcts_result:
wandb.log({"mcts/tree_visualization": wandb.Html(f"<pre>{mcts_result['tree_visualization']}</pre>")})
except Exception as e:
print(f"W&B MCTS result log error: {e}")
def log_consensus(self, consensus_score: float, agents_used: list[str], final_response: str):
"""Log consensus metrics.
Args:
consensus_score: Agreement score between agents (0-1)
agents_used: List of agent names that were used
final_response: Final synthesized response
"""
if not self.enabled or not self.run:
return
try:
wandb.log(
{
"consensus/score": consensus_score,
"consensus/num_agents": len(agents_used),
"consensus/agents": ", ".join(agents_used),
"consensus/final_response_length": len(final_response),
}
)
# Categorize consensus level
if consensus_score > 0.7:
consensus_level = "high"
elif consensus_score > 0.4:
consensus_level = "medium"
else:
consensus_level = "low"
wandb.log({"consensus/level": consensus_level})
except Exception as e:
print(f"W&B consensus log error: {e}")
def log_performance(self, total_time_ms: float):
"""Log overall performance metrics.
Args:
total_time_ms: Total execution time in milliseconds
"""
if not self.enabled or not self.run:
return
try:
wandb.log({"performance/total_time_ms": total_time_ms, "performance/total_time_s": total_time_ms / 1000})
except Exception as e:
print(f"W&B performance log error: {e}")
def log_full_result(self, result: dict[str, Any]):
"""Log the complete result as an artifact.
Args:
result: Full framework result dictionary
"""
if not self.enabled or not self.run:
return
try:
# Create artifact
artifact = wandb.Artifact(name=f"query_result_{self.run_id}", type="result")
# Add result as JSON
import json
import tempfile
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
json.dump(result, f, indent=2, default=str)
temp_path = f.name
artifact.add_file(temp_path, name="result.json")
wandb.log_artifact(artifact)
# Clean up temp file
os.unlink(temp_path)
except Exception as e:
print(f"W&B full result log error: {e}")
def log_query_summary(
self, query: str, use_hrm: bool, use_trm: bool, use_mcts: bool, consensus_score: float, total_time_ms: float
):
"""Log a summary row for the query.
Args:
query: The input query
use_hrm: Whether HRM was enabled
use_trm: Whether TRM was enabled
use_mcts: Whether MCTS was enabled
consensus_score: Final consensus score
total_time_ms: Total execution time
"""
if not self.enabled or not self.run:
return
try:
# Create summary table entry
summary_data = [
[
query[:100] + "..." if len(query) > 100 else query,
"βœ“" if use_hrm else "βœ—",
"βœ“" if use_trm else "βœ—",
"βœ“" if use_mcts else "βœ—",
f"{consensus_score:.1%}",
f"{total_time_ms:.2f}",
]
]
table = wandb.Table(data=summary_data, columns=["Query", "HRM", "TRM", "MCTS", "Consensus", "Time (ms)"])
wandb.log({"query_summary": table})
except Exception as e:
print(f"W&B summary log error: {e}")
def finish_run(self):
"""Finish the current W&B run."""
if not self.enabled or not self.run:
return
try:
wandb.finish()
self.run = None
self.run_id = None
except Exception as e:
print(f"W&B finish error: {e}")
def get_run_url(self) -> str | None:
"""Get the URL for the current run.
Returns:
URL string or None if no active run
"""
if not self.enabled or not self.run:
return None
try:
return self.run.get_url()
except Exception:
return None
# Global tracker instance shared by every caller of get_tracker(); it stays
# None until the first get_tracker() call creates it.
_global_tracker: WandBTracker | None = None
def get_tracker(
    project_name: str = "langgraph-mcts-demo", entity: str | None = None, enabled: bool = True
) -> WandBTracker:
    """Return the shared W&B tracker, creating it on first use.

    The arguments are only consulted when the shared instance does not exist
    yet; once it has been created, later calls return that same instance and
    their arguments are ignored.

    Args:
        project_name: W&B project name
        entity: W&B entity
        enabled: Whether tracking is enabled

    Returns:
        WandBTracker instance
    """
    global _global_tracker

    # Fast path: the shared tracker has already been built.
    if _global_tracker is not None:
        return _global_tracker

    _global_tracker = WandBTracker(project_name=project_name, entity=entity, enabled=enabled)
    return _global_tracker
def is_wandb_available() -> bool:
    """Report whether W&B is available.

    Returns:
        The value of the module-level ``WANDB_AVAILABLE`` flag.
    """
    available = WANDB_AVAILABLE
    return available