"""
LangGraph Multi-Agent MCTS Framework - Hugging Face Spaces Demo

A proof-of-concept demonstration of multi-agent reasoning with Monte Carlo Tree Search.
"""

import asyncio
import time
from dataclasses import dataclass

import gradio as gr
import numpy as np

# Demo-specific simplified implementations
from demo_src.agents_demo import HRMAgent, TRMAgent
from demo_src.llm_mock import HuggingFaceClient, MockLLMClient
from demo_src.mcts_demo import MCTSDemo
from demo_src.wandb_tracker import WandBTracker, is_wandb_available


@dataclass
class AgentResult:
    """Result from a single agent."""

    agent_name: str
    response: str
    confidence: float
    reasoning_steps: list[str]
    execution_time_ms: float


@dataclass
class FrameworkResult:
    """Combined result from all agents."""

    query: str
    hrm_result: AgentResult | None
    trm_result: AgentResult | None
    mcts_result: dict | None
    consensus_score: float
    final_response: str
    total_time_ms: float
    metadata: dict


class MultiAgentFrameworkDemo:
    """Simplified multi-agent framework for Hugging Face Spaces demo."""

    def __init__(self, use_hf_inference: bool = False, hf_model: str = ""):
        """Initialize the demo framework.

        Args:
            use_hf_inference: Use Hugging Face Inference API instead of mock
            hf_model: Hugging Face model ID for inference
        """
        self.use_hf_inference = use_hf_inference
        self.hf_model = hf_model

        # Initialize components
        if use_hf_inference and hf_model:
            self.llm_client = HuggingFaceClient(model_id=hf_model)
        else:
            self.llm_client = MockLLMClient()

        self.hrm_agent = HRMAgent(self.llm_client)
        self.trm_agent = TRMAgent(self.llm_client)
        self.mcts = MCTSDemo()

    async def process_query(
        self,
        query: str,
        use_hrm: bool = True,
        use_trm: bool = True,
        use_mcts: bool = False,
        mcts_iterations: int = 25,
        exploration_weight: float = 1.414,
        seed: int | None = None,
    ) -> FrameworkResult:
        """Process a query through the multi-agent framework.

        Args:
            query: The input query to process
            use_hrm: Enable Hierarchical Reasoning Module
            use_trm: Enable Tree Reasoning Module
            use_mcts: Enable Monte Carlo Tree Search
            mcts_iterations: Number of MCTS iterations
            exploration_weight: UCB1 exploration constant C (sqrt(2) ~= 1.414 by default)
            seed: Random seed for reproducibility

        Returns:
            FrameworkResult with all agent outputs and consensus
        """
        start_time = time.perf_counter()

        hrm_result = None
        trm_result = None
        mcts_result = None

        # Run enabled agents
        tasks = []
        agent_names = []

        if use_hrm:
            tasks.append(self._run_hrm(query))
            agent_names.append("hrm")

        if use_trm:
            tasks.append(self._run_trm(query))
            agent_names.append("trm")

        if use_mcts:
            tasks.append(self._run_mcts(query, mcts_iterations, exploration_weight, seed))
            agent_names.append("mcts")

        # Execute agents concurrently
        if tasks:
            results = await asyncio.gather(*tasks, return_exceptions=True)

            for name, result in zip(agent_names, results, strict=False):
                if isinstance(result, Exception):
                    continue
                if name == "hrm":
                    hrm_result = result
                elif name == "trm":
                    trm_result = result
                elif name == "mcts":
                    mcts_result = result

        # Calculate consensus score
        consensus_score = self._calculate_consensus(hrm_result, trm_result, mcts_result)

        # Generate final synthesized response
        final_response = self._synthesize_response(query, hrm_result, trm_result, mcts_result, consensus_score)

        total_time = (time.perf_counter() - start_time) * 1000

        return FrameworkResult(
            query=query,
            hrm_result=hrm_result,
            trm_result=trm_result,
            mcts_result=mcts_result,
            consensus_score=consensus_score,
            final_response=final_response,
            total_time_ms=round(total_time, 2),
            metadata={
                "agents_used": agent_names,
                "mcts_config": (
                    {"iterations": mcts_iterations, "exploration_weight": exploration_weight, "seed": seed}
                    if use_mcts
                    else None
                ),
            },
        )
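
    # Example (illustrative) of driving the pipeline outside the Gradio UI:
    #   fw = MultiAgentFrameworkDemo()
    #   result = asyncio.run(fw.process_query("Design a URL shortener", use_mcts=True, seed=42))
    #   print(result.final_response, result.consensus_score)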

    async def _run_hrm(self, query: str) -> AgentResult:
        """Run Hierarchical Reasoning Module."""
        start = time.perf_counter()
        result = await self.hrm_agent.process(query)
        elapsed = (time.perf_counter() - start) * 1000

        return AgentResult(
            agent_name="HRM (Hierarchical Reasoning)",
            response=result["response"],
            confidence=result["confidence"],
            reasoning_steps=result["steps"],
            execution_time_ms=round(elapsed, 2),
        )

    async def _run_trm(self, query: str) -> AgentResult:
        """Run Tree Reasoning Module."""
        start = time.perf_counter()
        result = await self.trm_agent.process(query)
        elapsed = (time.perf_counter() - start) * 1000

        return AgentResult(
            agent_name="TRM (Iterative Refinement)",
            response=result["response"],
            confidence=result["confidence"],
            reasoning_steps=result["steps"],
            execution_time_ms=round(elapsed, 2),
        )

    async def _run_mcts(self, query: str, iterations: int, exploration_weight: float, seed: int | None) -> dict:
        """Run Monte Carlo Tree Search."""
        start = time.perf_counter()

        # MCTSDemo.search is async and delegates to the production framework's search.
        result = await self.mcts.search(query=query, iterations=iterations, exploration_weight=exploration_weight, seed=seed)

        elapsed = (time.perf_counter() - start) * 1000
        result["execution_time_ms"] = round(elapsed, 2)

        return result
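
    # For reference: UCB1 scores a child during selection as
    #   Q(child) / N(child) + C * sqrt(ln(N(parent)) / N(child)),
    # where C is the exploration weight; C = sqrt(2) ~= 1.414 is the classical
    # default from the UCB1 bandit analysis (how MCTSDemo applies it may vary).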

    def _calculate_consensus(
        self, hrm_result: AgentResult | None, trm_result: AgentResult | None, mcts_result: dict | None
    ) -> float:
        """Calculate agreement score between agents."""
        confidences = []

        if hrm_result:
            confidences.append(hrm_result.confidence)
        if trm_result:
            confidences.append(trm_result.confidence)
        if mcts_result:
            confidences.append(mcts_result.get("best_value", 0.5))

        if not confidences:
            return 0.0

        # Consensus blends average confidence with inter-agent agreement.
        # A single agent has no agreement signal, so return its confidence.
        if len(confidences) == 1:
            return round(confidences[0], 3)

        avg_confidence = np.mean(confidences)
        std_confidence = np.std(confidences)

        # Higher consensus when agents agree (low std) and are confident (high avg)
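        # Worked example: confidences [0.8, 0.9] -> avg 0.85, std 0.05,
        # agreement_factor = 1 - 0.05 * 2 = 0.9, consensus = 0.85 * 0.9 = 0.765.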
        agreement_factor = max(0, 1 - std_confidence * 2)
        consensus = avg_confidence * agreement_factor

        return round(min(1.0, consensus), 3)

    def _synthesize_response(
        self,
        query: str,
        hrm_result: AgentResult | None,
        trm_result: AgentResult | None,
        mcts_result: dict | None,
        consensus_score: float,
    ) -> str:
        """Synthesize final response from all agent outputs."""
        parts = []

        if hrm_result and hrm_result.confidence > 0.5:
            parts.append(f"[HRM] {hrm_result.response}")

        if trm_result and trm_result.confidence > 0.5:
            parts.append(f"[TRM] {trm_result.response}")

        if mcts_result and mcts_result.get("best_value", 0) > 0.5:
            parts.append(f"[MCTS] Best path: {mcts_result.get('best_action', 'N/A')}")

        if not parts:
            truncated_query = f"{query[:80]}..." if len(query) > 80 else query
            return f"Insufficient confidence to answer query: '{truncated_query}'."

        synthesis = " | ".join(parts)

        if consensus_score > 0.7:
            return f"HIGH CONSENSUS ({consensus_score:.1%}): {synthesis}"
        elif consensus_score > 0.4:
            return f"MODERATE CONSENSUS ({consensus_score:.1%}): {synthesis}"
        else:
            return f"LOW CONSENSUS ({consensus_score:.1%}): {synthesis}"


# Global framework instance
framework: MultiAgentFrameworkDemo | None = None
wandb_tracker: WandBTracker | None = None


def initialize_framework(use_hf: bool, model_id: str):
    """Initialize or reinitialize the framework."""
    global framework
    framework = MultiAgentFrameworkDemo(use_hf_inference=use_hf, hf_model=model_id)
    return "Framework initialized successfully!"


def process_query_sync(
    query: str,
    use_hrm: bool,
    use_trm: bool,
    use_mcts: bool,
    mcts_iterations: int,
    exploration_weight: float,
    seed: int,
    enable_wandb: bool = False,
    wandb_project: str = "langgraph-mcts-demo",
    wandb_run_name: str = "",
):
    """Synchronous wrapper for async processing."""
    global framework, wandb_tracker

    if framework is None:
        framework = MultiAgentFrameworkDemo()

    if not query.strip():
        return "Please enter a query.", {}, "", {}, ""

    # A non-positive seed means "unseeded" (the UI's Number input defaults to 0)
    seed_value = seed if seed > 0 else None

    # Initialize W&B tracking if enabled
    wandb_url = ""
    if enable_wandb and is_wandb_available():
        if wandb_tracker is None:
            wandb_tracker = WandBTracker(project_name=wandb_project, enabled=True)

        # Start a new run
        run_name = wandb_run_name if wandb_run_name.strip() else None
        config = {
            "query": query[:200],  # Truncate for config
            "use_hrm": use_hrm,
            "use_trm": use_trm,
            "use_mcts": use_mcts,
            "mcts_iterations": mcts_iterations,
            "exploration_weight": exploration_weight,
            "seed": seed_value,
        }
        wandb_tracker.init_run(run_name=run_name, config=config)

    # Drive the async pipeline to completion from Gradio's synchronous callback
    result = asyncio.run(
        framework.process_query(
            query=query,
            use_hrm=use_hrm,
            use_trm=use_trm,
            use_mcts=use_mcts,
            mcts_iterations=int(mcts_iterations),
            exploration_weight=exploration_weight,
            seed=seed_value,
        )
    )

    # Format outputs
    final_response = result.final_response

    # Agent details
    agent_details = {}
    if result.hrm_result:
        agent_details["HRM"] = {
            "response": result.hrm_result.response,
            "confidence": f"{result.hrm_result.confidence:.1%}",
            "reasoning_steps": result.hrm_result.reasoning_steps,
            "time_ms": result.hrm_result.execution_time_ms,
        }

        # Log to W&B
        if enable_wandb and wandb_tracker:
            wandb_tracker.log_agent_result(
                "HRM",
                result.hrm_result.response,
                result.hrm_result.confidence,
                result.hrm_result.execution_time_ms,
                result.hrm_result.reasoning_steps,
            )

    if result.trm_result:
        agent_details["TRM"] = {
            "response": result.trm_result.response,
            "confidence": f"{result.trm_result.confidence:.1%}",
            "reasoning_steps": result.trm_result.reasoning_steps,
            "time_ms": result.trm_result.execution_time_ms,
        }

        # Log to W&B
        if enable_wandb and wandb_tracker:
            wandb_tracker.log_agent_result(
                "TRM",
                result.trm_result.response,
                result.trm_result.confidence,
                result.trm_result.execution_time_ms,
                result.trm_result.reasoning_steps,
            )

    if result.mcts_result:
        agent_details["MCTS"] = result.mcts_result

        # Log to W&B
        if enable_wandb and wandb_tracker:
            wandb_tracker.log_mcts_result(result.mcts_result)

    # Log consensus and performance to W&B
    if enable_wandb and wandb_tracker:
        wandb_tracker.log_consensus(result.consensus_score, result.metadata["agents_used"], result.final_response)
        wandb_tracker.log_performance(result.total_time_ms)
        wandb_tracker.log_query_summary(query, use_hrm, use_trm, use_mcts, result.consensus_score, result.total_time_ms)

        # Get run URL
        wandb_url = wandb_tracker.get_run_url() or ""

        # Finish the run
        wandb_tracker.finish_run()

    # Metrics
    metrics = f"""
**Consensus Score:** {result.consensus_score:.1%}
**Total Processing Time:** {result.total_time_ms:.2f} ms
**Agents Used:** {", ".join(result.metadata["agents_used"])}
"""

    if wandb_url:
        metrics += f"\n**W&B Run:** [{wandb_url}]({wandb_url})"

    # Full JSON result
    full_result = {
        "query": result.query,
        "final_response": result.final_response,
        "consensus_score": result.consensus_score,
        "total_time_ms": result.total_time_ms,
        "metadata": result.metadata,
        "agent_details": agent_details,
        "wandb_url": wandb_url,
    }

    return final_response, agent_details, metrics, full_result, wandb_url


def visualize_mcts_tree(mcts_result: dict) -> str:
    """Create ASCII visualization of MCTS tree."""
    if not mcts_result or "tree_visualization" not in mcts_result:
        return "No MCTS tree data available"

    return mcts_result["tree_visualization"]


# Example queries for demonstration
EXAMPLE_QUERIES = [
    "What are the key factors to consider when choosing between microservices and monolithic architecture?",
    "How can we optimize a Python application that processes 10GB of log files daily?",
    "What is the best approach to implement rate limiting in a distributed system?",
    "Should we use SQL or NoSQL database for a social media application with 1M users?",
    "How to design a fault-tolerant message queue system?",
]


# Gradio Interface
with gr.Blocks(
    title="LangGraph Multi-Agent MCTS Demo",
    theme=gr.themes.Soft(),
    css="""
    .agent-box { border: 1px solid #ddd; padding: 10px; border-radius: 5px; margin: 5px 0; }
    .consensus-high { color: #28a745; font-weight: bold; }
    .consensus-medium { color: #ffc107; font-weight: bold; }
    .consensus-low { color: #dc3545; font-weight: bold; }
    """,
) as demo:
    gr.Markdown(
        """
        # LangGraph Multi-Agent MCTS Framework

        **Proof-of-Concept Demo** - Multi-agent reasoning with Monte Carlo Tree Search

        This demo showcases:
        - **HRM**: Hierarchical Reasoning Module - breaks down complex queries
        - **TRM**: Tree Reasoning Module - iterative refinement of responses
        - **MCTS**: Monte Carlo Tree Search - strategic exploration of solution space
        - **Consensus**: Agreement scoring between agents

        ---
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            query_input = gr.Textbox(
                label="Query", placeholder="Enter your reasoning task or question...", lines=3, max_lines=10
            )

            gr.Markdown("**Example Queries:**")
            example_dropdown = gr.Dropdown(choices=EXAMPLE_QUERIES, label="Select an example", interactive=True)

            def load_example(example):
                return example

            example_dropdown.change(load_example, example_dropdown, query_input)

        with gr.Column(scale=1):
            gr.Markdown("**Agent Configuration**")
            use_hrm = gr.Checkbox(label="Enable HRM (Hierarchical)", value=True)
            use_trm = gr.Checkbox(label="Enable TRM (Iterative)", value=True)
            use_mcts = gr.Checkbox(label="Enable MCTS", value=False)

            gr.Markdown("**MCTS Parameters**")
            mcts_iterations = gr.Slider(
                minimum=10,
                maximum=100,
                value=25,
                step=5,
                label="Iterations",
                info="More iterations = better search, but slower",
            )
            exploration_weight = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.414,
                step=0.1,
                label="Exploration Weight (C)",
                info="Higher = more exploration, Lower = more exploitation",
            )
            seed_input = gr.Number(label="Random Seed (0 for random)", value=0, precision=0)

    with gr.Accordion("Weights & Biases Tracking", open=False):
        gr.Markdown(
            """
            **Experiment Tracking with W&B**

            Track your experiments, visualize metrics, and compare runs.
            Requires W&B API key set in Space secrets as `WANDB_API_KEY`.
            """
        )
        with gr.Row():
            enable_wandb = gr.Checkbox(
                label="Enable W&B Tracking", value=False, info="Log metrics and results to Weights & Biases"
            )
            wandb_project = gr.Textbox(
                label="Project Name", value="langgraph-mcts-demo", placeholder="Your W&B project name"
            )
            wandb_run_name = gr.Textbox(label="Run Name (optional)", value="", placeholder="Auto-generated if empty")

        wandb_status = gr.Markdown(f"**W&B Status:** {'Available' if is_wandb_available() else 'Not installed'}")

    process_btn = gr.Button("Process Query", variant="primary", size="lg")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Final Response")
            final_response_output = gr.Textbox(label="Synthesized Response", lines=4, interactive=False)

            gr.Markdown("### Performance Metrics")
            metrics_output = gr.Markdown()

        with gr.Column():
            gr.Markdown("### Agent Details")
            agent_details_output = gr.JSON(label="Individual Agent Results")

    with gr.Accordion("Full JSON Result", open=False):
        full_result_output = gr.JSON(label="Complete Framework Output")

    with gr.Accordion("W&B Run Details", open=False, visible=True):
        wandb_url_output = gr.Textbox(
            label="W&B Run URL", interactive=False, placeholder="Enable W&B tracking to see run URL here"
        )

    # Wire up the processing
    process_btn.click(
        fn=process_query_sync,
        inputs=[
            query_input,
            use_hrm,
            use_trm,
            use_mcts,
            mcts_iterations,
            exploration_weight,
            seed_input,
            enable_wandb,
            wandb_project,
            wandb_run_name,
        ],
        outputs=[final_response_output, agent_details_output, metrics_output, full_result_output, wandb_url_output],
    )

    gr.Markdown(
        """
        ---

        ### About This Demo

        This is a **proof-of-concept** demonstration of the LangGraph Multi-Agent MCTS Framework.

        **Features:**
        - Multi-agent orchestration with consensus scoring
        - Monte Carlo Tree Search for strategic reasoning
        - Configurable exploration vs exploitation trade-offs
        - Deterministic results with seeded randomness
        - **Weights & Biases integration** for experiment tracking

        **Limitations (POC):**
        - Uses mock/simplified LLM responses (not a production LLM)
        - Limited to demonstration scenarios
        - No persistent storage or RAG
        - Simplified MCTS implementation

        **Full Framework:** [GitHub Repository](https://github.com/ianshank/langgraph_multi_agent_mcts)

        ---
        *Built with LangGraph, Gradio, Weights & Biases, and Python*
        """
    )


if __name__ == "__main__":
    # Initialize with mock client for demo
    framework = MultiAgentFrameworkDemo(use_hf_inference=False)

    # Launch on 0.0.0.0:7860, the standard binding for Hugging Face Spaces
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)