File size: 14,559 Bytes
40ee6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
"""
Input validation models for LangGraph Multi-Agent MCTS framework.

Provides:
- Pydantic models for all external inputs
- Query sanitization and length limits
- Configuration validation
- MCP tool input validation with strict type checking
- Security-focused input processing
"""

import re
from datetime import datetime, timezone
from typing import Any

from pydantic import (
    BaseModel,
    ConfigDict,
    Field,
    field_validator,
    model_validator,
)

# --- Validation limits --------------------------------------------------------
# Length limits for free-text inputs (MAX_CONTEXT_LENGTH is not referenced by
# the models in this module section).
MIN_QUERY_LENGTH = 1
MAX_QUERY_LENGTH = 10_000
MAX_CONTEXT_LENGTH = 50_000

# Bounds on the number of MCTS simulation iterations.
MIN_ITERATIONS = 1
MAX_ITERATIONS = 10_000

# Bounds on the UCB1 exploration constant (MCTSConfig.exploration_weight).
MIN_EXPLORATION_WEIGHT = 0.0
MAX_EXPLORATION_WEIGHT = 10.0

# Maximum number of queries accepted by a single BatchQueryInput.
MAX_BATCH_SIZE = 100


class QueryInput(BaseModel):
    """
    Validated query input for the multi-agent framework.

    Performs sanitization and security checks on user queries.

    Fields:
        query: Free-text user query. Its length is bounded by MIN_QUERY_LENGTH /
            MAX_QUERY_LENGTH on the raw value, before sanitization runs.
        use_rag: Whether RAG context retrieval is enabled.
        use_mcts: Whether MCTS simulation is enabled.
        thread_id: Optional conversation thread ID restricted to
            ``[a-zA-Z0-9_-]`` (max 100 chars) so it is safe as a storage key.
    """

    model_config = ConfigDict(
        strict=True,
        validate_assignment=True,  # validators also run when a field is reassigned
        extra="forbid",
    )

    query: str = Field(
        ..., min_length=MIN_QUERY_LENGTH, max_length=MAX_QUERY_LENGTH, description="User query to process"
    )

    use_rag: bool = Field(default=True, description="Enable RAG context retrieval")

    use_mcts: bool = Field(default=False, description="Enable MCTS simulation for tactical planning")

    thread_id: str | None = Field(
        default=None,
        max_length=100,
        pattern=r"^[a-zA-Z0-9_-]+$",
        description="Conversation thread ID for state persistence",
    )

    @field_validator("query")
    @classmethod
    def sanitize_query(cls, v: str) -> str:
        """
        Sanitize query input for security.

        Removes NUL bytes, trims and collapses whitespace, then rejects queries
        containing patterns commonly used for script or template injection.

        Args:
            v: Raw query string (already length-checked by the field constraints).

        Returns:
            The sanitized query.

        Raises:
            ValueError: If the query is empty after sanitization or matches one
                of the suspicious patterns.
        """
        # Drop NUL bytes *before* trimming and the emptiness check; otherwise
        # "\x00" would sanitize to "" and "abc \x00" to "abc " (untrimmed).
        v = v.replace("\x00", "")

        # Strip leading/trailing whitespace
        v = v.strip()

        # Check for empty query after stripping
        if not v:
            raise ValueError("Query cannot be empty or contain only whitespace")

        # Collapse runs of whitespace (including newlines) into single spaces
        v = re.sub(r"\s+", " ", v)

        # Check for suspicious patterns (basic injection prevention)
        suspicious_patterns = [
            r"<script[^>]*>",  # Script tags
            r"javascript:",  # JavaScript URLs
            # Event handlers (onclick=, onerror=, ...). The word boundary keeps
            # ordinary text such as "session_id = 5" or "condition = x" from
            # being rejected, because "on" inside a word is not a handler name.
            r"\bon\w+\s*=",
            r"\{\{.*\}\}",  # Template injection
            r"\$\{.*\}",  # Template literals
        ]

        for pattern in suspicious_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError(f"Query contains potentially unsafe content matching pattern: {pattern}")

        return v

    @field_validator("thread_id")
    @classmethod
    def validate_thread_id(cls, v: str | None) -> str | None:
        """
        Reject path-like thread IDs so they are safe as storage keys.

        The field ``pattern`` already restricts IDs to ``[a-zA-Z0-9_-]``. This
        is a defense-in-depth check in case that pattern is ever loosened.
        """
        if v is not None and (".." in v or "/" in v or "\\" in v):
            raise ValueError("Thread ID contains invalid path characters")
        return v


class MCTSConfig(BaseModel):
    """
    Bounded parameters for a Monte Carlo Tree Search run.

    Every field is range-checked through ``Field`` constraints. The exploration
    weight is additionally bounds-checked by a validator, which emits a
    ``UserWarning`` when the value lies outside the 0.5-3.0 band.
    """

    model_config = ConfigDict(extra="forbid", strict=True)

    iterations: int = Field(
        default=100,
        ge=MIN_ITERATIONS,
        le=MAX_ITERATIONS,
        description="Number of MCTS simulation iterations",
    )

    exploration_weight: float = Field(
        default=1.414,
        ge=MIN_EXPLORATION_WEIGHT,
        le=MAX_EXPLORATION_WEIGHT,
        description="UCB1 exploration constant (c parameter)",
    )

    max_depth: int = Field(
        default=10,
        ge=1,
        le=50,
        description="Maximum tree depth for MCTS expansion",
    )

    simulation_timeout_seconds: float = Field(
        default=30.0,
        ge=1.0,
        le=300.0,
        description="Timeout for MCTS simulation phase",
    )

    @field_validator("exploration_weight")
    @classmethod
    def validate_exploration_weight(cls, v: float) -> float:
        """
        Enforce the hard limits on the exploration weight and flag odd values.

        Raises:
            ValueError: If ``v`` lies outside
                [MIN_EXPLORATION_WEIGHT, MAX_EXPLORATION_WEIGHT].

        Warns:
            UserWarning: If ``v`` is below 0.5 or above 3.0.
        """
        within_limits = MIN_EXPLORATION_WEIGHT <= v <= MAX_EXPLORATION_WEIGHT
        if not within_limits:
            raise ValueError(
                f"Exploration weight must be between {MIN_EXPLORATION_WEIGHT} and {MAX_EXPLORATION_WEIGHT}"
            )

        unusual = v < 0.5 or v > 3.0
        if unusual:
            import warnings

            warnings.warn(
                f"Exploration weight {v} is outside typical range (0.5-3.0). "
                "This may lead to suboptimal search behavior.",
                UserWarning,
                stacklevel=2,
            )
        return v


class AgentConfig(BaseModel):
    """
    Tunable settings for the HRM/TRM agents.

    Unlike several other models in this module, this config does not enable
    strict mode, so pydantic's default (lax) coercion applies. Unknown keys
    are still rejected.
    """

    model_config = ConfigDict(extra="forbid")

    max_iterations: int = Field(
        default=3,
        ge=1,
        le=20,
        description="Maximum iterations for agent refinement",
    )

    consensus_threshold: float = Field(
        default=0.75,
        ge=0.0,
        le=1.0,
        description="Consensus threshold for agent agreement",
    )

    temperature: float = Field(
        default=0.7,
        ge=0.0,
        le=2.0,
        description="LLM temperature for response generation",
    )

    max_tokens: int = Field(
        default=2048,
        ge=1,
        le=128000,
        description="Maximum tokens in LLM response",
    )

    @field_validator("temperature")
    @classmethod
    def validate_temperature(cls, v: float) -> float:
        """Reject temperatures below 0.0 or above 2.0 (mirrors the Field bounds)."""
        too_cold = v < 0.0
        too_hot = v > 2.0
        if too_cold or too_hot:
            raise ValueError("Temperature must be between 0.0 and 2.0")
        return v


class RAGConfig(BaseModel):
    """
    Settings for retrieval-augmented generation.

    Field constraints bound each value on its own. A model-level check also
    requires ``chunk_overlap`` to be strictly smaller than ``chunk_size``.
    """

    model_config = ConfigDict(extra="forbid")

    top_k: int = Field(
        default=5,
        ge=1,
        le=50,
        description="Number of documents to retrieve",
    )

    similarity_threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Minimum similarity score for retrieved documents",
    )

    chunk_size: int = Field(
        default=1000,
        ge=100,
        le=10000,
        description="Document chunk size for embedding",
    )

    chunk_overlap: int = Field(
        default=200,
        ge=0,
        le=2000,
        description="Overlap between document chunks",
    )

    @model_validator(mode="after")
    def validate_chunk_overlap(self) -> "RAGConfig":
        """Reject configurations whose overlap is as large as a whole chunk."""
        overlap_fits = self.chunk_overlap < self.chunk_size
        if not overlap_fits:
            raise ValueError("Chunk overlap must be less than chunk size")
        return self


class MCPToolInput(BaseModel):
    """
    Base validation model for MCP (Model Context Protocol) tool inputs.

    Provides strict validation for external tool invocations.

    Fields:
        tool_name: Identifier-like tool name (letter first, then letters,
            digits, ``_`` or ``-``; at most 100 characters).
        parameters: Keyword arguments for the tool. There are at most 50 keys,
            each key must be a Python-identifier-like string of at most 100
            characters, and ``str(parameters)`` must not exceed 100,000 chars.
        timeout_seconds: Execution timeout in seconds (1-300).
    """

    model_config = ConfigDict(
        strict=True,
        extra="forbid",
    )

    tool_name: str = Field(
        ...,
        min_length=1,
        max_length=100,
        pattern=r"^[a-zA-Z][a-zA-Z0-9_-]*$",
        description="Name of the MCP tool to invoke",
    )

    parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters as key-value pairs")

    timeout_seconds: float = Field(default=30.0, ge=1.0, le=300.0, description="Timeout for tool execution")

    @field_validator("tool_name")
    @classmethod
    def validate_tool_name(cls, v: str) -> str:
        """
        Validate tool name is safe and follows naming conventions.

        These checks duplicate the Field ``pattern``/``max_length`` constraints
        as defense in depth (they still apply if a subclass overrides the field).
        """
        # Prevent path traversal in tool names
        if ".." in v or "/" in v or "\\" in v:
            raise ValueError("Tool name contains invalid characters")

        # Prevent overly long names
        if len(v) > 100:
            raise ValueError("Tool name exceeds maximum length of 100 characters")

        return v

    @field_validator("parameters")
    @classmethod
    def validate_parameters(cls, v: dict[str, Any]) -> dict[str, Any]:
        """
        Validate tool parameters for security.

        Raises:
            ValueError: If the stringified parameters exceed 100,000 characters,
                there are more than 50 keys, or a key is not a string, is longer
                than 100 characters, or is not identifier-like.
        """
        # Check for reasonable size
        if len(str(v)) > 100000:
            raise ValueError("Tool parameters exceed maximum size")

        # Check parameter count
        if len(v) > 50:
            raise ValueError("Too many parameters (maximum 50)")

        # Validate parameter keys
        for key in v:
            if not isinstance(key, str):
                raise ValueError("Parameter keys must be strings")
            if len(key) > 100:
                raise ValueError(f"Parameter key '{key[:20]}...' exceeds maximum length")
            # fullmatch, not match + "$": "$" also matches just before a
            # trailing newline, which would let keys like "name\n" through.
            if not re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", key):
                raise ValueError(f"Invalid parameter key format: {key}")

        return v


class FileReadInput(MCPToolInput):
    """
    Validated input for file reading operations.

    Implements path traversal protection. ``tool_name`` defaults to
    ``"read_file"`` and is frozen against reassignment after construction.
    """

    tool_name: str = Field(default="read_file", frozen=True)

    file_path: str = Field(..., min_length=1, max_length=1000, description="Path to file to read")

    @field_validator("file_path")
    @classmethod
    def validate_file_path(cls, v: str) -> str:
        """
        Validate file path for security concerns.

        Args:
            v: Raw path (length already checked by the field constraints).

        Returns:
            The path with surrounding whitespace removed.

        Raises:
            ValueError: If the path is blank, contains a NUL byte, contains
                ``..``, or points into a restricted system directory. The
                restricted-directory check applies with either separator style.

        Warns:
            UserWarning: If the path is absolute (starts with ``/``).
        """
        # Normalize path
        v = v.strip()

        # min_length is checked before stripping, so a whitespace-only value
        # would otherwise come out of here as an empty path.
        if not v:
            raise ValueError("File path cannot be empty or contain only whitespace")

        # Embedded NUL bytes can truncate paths in lower-level file APIs.
        if "\x00" in v:
            raise ValueError("File path contains a null byte")

        # Check for path traversal attempts
        if ".." in v:
            raise ValueError("Path traversal detected: '..' not allowed in file path")

        # Check for absolute paths (may be allowed in some contexts)
        if v.startswith("/"):
            import warnings

            warnings.warn(
                "Absolute file path provided. Ensure this is within allowed directories.", UserWarning, stacklevel=2
            )

        # Check for suspicious patterns
        suspicious = [
            "/etc/",
            "/root/",
            "~/.ssh/",
            "/var/",
            "\\windows\\",
            "\\system32\\",
        ]
        # Compare against the path in both separator conventions, so that
        # "C:/Windows/System32/..." still matches "\\windows\\" and
        # "\\var\\log" still matches "/var/".
        lowered = v.lower()
        variants = (lowered, lowered.replace("\\", "/"), lowered.replace("/", "\\"))
        for pattern in suspicious:
            needle = pattern.lower()
            if any(needle in variant for variant in variants):
                raise ValueError(f"File path contains restricted directory: {pattern}")

        return v


class WebFetchInput(MCPToolInput):
    """
    Validated input for web fetch operations.

    Implements URL validation and security checks. ``tool_name`` defaults to
    ``"web_fetch"`` and is frozen against reassignment after construction.
    """

    tool_name: str = Field(default="web_fetch", frozen=True)

    url: str = Field(..., min_length=1, max_length=2000, description="URL to fetch")

    @field_validator("url")
    @classmethod
    def validate_url(cls, v: str) -> str:
        """
        Validate URL for security.

        Only ``https`` URLs are accepted. Plain ``http`` is allowed only when
        the parsed host is exactly ``localhost`` or ``127.0.0.1``.

        Args:
            v: Raw URL (length already checked by the field constraints).

        Returns:
            The URL with surrounding whitespace removed.

        Raises:
            ValueError: If the scheme/host combination is not allowed, the URL
                contains ``< > ' " ;``, or it does not look like an
                ``http(s)://host...`` URL.
        """
        from urllib.parse import urlsplit

        v = v.strip()

        try:
            parts = urlsplit(v)
        except ValueError:
            # e.g. unbalanced IPv6 brackets in the netloc
            raise ValueError("Invalid URL format") from None

        # Compare the *parsed* host rather than a string prefix. A prefix test
        # such as startswith("http://localhost") also accepts
        # "http://localhost.evil.com" and "http://localhost@evil.com" (whose
        # real host is evil.com). urlsplit lower-cases the scheme and the
        # hostname.
        is_https = parts.scheme == "https"
        is_local_http = parts.scheme == "http" and parts.hostname in ("localhost", "127.0.0.1")
        if not (is_https or is_local_http):
            raise ValueError("URL must use HTTPS protocol (except for localhost)")

        # Check for suspicious patterns
        if any(char in v for char in ["<", ">", "'", '"', ";"]):
            raise ValueError("URL contains invalid characters")

        # Validate basic URL structure (also rejects e.g. "https:host" without //)
        url_pattern = r"^https?://[^\s/$.?#].[^\s]*$"
        if not re.match(url_pattern, v, re.IGNORECASE):
            raise ValueError("Invalid URL format")

        return v


class BatchQueryInput(BaseModel):
    """
    Several :class:`QueryInput` objects submitted together.

    Holds between one and ``MAX_BATCH_SIZE`` queries. Strict mode is enabled
    and unknown keys are rejected.
    """

    model_config = ConfigDict(extra="forbid", strict=True)

    queries: list[QueryInput] = Field(
        ...,
        description="List of queries to process in batch",
        min_length=1,
        max_length=MAX_BATCH_SIZE,
    )

    parallel: bool = Field(
        default=False,
        description="Process queries in parallel (if system supports)",
    )

    @field_validator("queries")
    @classmethod
    def validate_batch_size(cls, v: list[QueryInput]) -> list[QueryInput]:
        """Re-check the batch length against the same bounds as the Field constraints."""
        size = len(v)
        if size > MAX_BATCH_SIZE:
            raise ValueError(f"Batch size exceeds maximum of {MAX_BATCH_SIZE}")
        if size == 0:
            raise ValueError("Batch must contain at least one query")
        return v


class APIRequestMetadata(BaseModel):
    """
    Metadata for API request tracking and audit logging.

    Used for security monitoring and rate limiting.

    Fields:
        request_id: Required identifier restricted to ``[a-zA-Z0-9_-]``.
        timestamp: Naive ``datetime`` holding the current UTC time by default.
        client_id: Optional identifier restricted to ``[a-zA-Z0-9_-]``.
        source_ip: Optional IPv4/IPv6 address string, checked with ``ipaddress``.
    """

    model_config = ConfigDict(
        extra="forbid",
    )

    request_id: str = Field(
        ..., min_length=1, max_length=100, pattern=r"^[a-zA-Z0-9_-]+$", description="Unique request identifier"
    )

    timestamp: datetime = Field(
        # Produces the same naive-UTC value datetime.utcnow() did, without the
        # DeprecationWarning utcnow() emits on Python 3.12+. It stays naive so
        # existing comparisons against naive datetimes keep working.
        default_factory=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        description="Request timestamp (UTC)",
    )

    client_id: str | None = Field(
        default=None, max_length=100, pattern=r"^[a-zA-Z0-9_-]+$", description="Client identifier for rate limiting"
    )

    source_ip: str | None = Field(default=None, description="Source IP address (for audit logging)")

    @field_validator("source_ip")
    @classmethod
    def validate_ip_address(cls, v: str | None) -> str | None:
        """
        Validate IP address format.

        Returns:
            The address unchanged (not normalized), or ``None``.

        Raises:
            ValueError: If ``v`` is not a valid IPv4 or IPv6 address.
        """
        if v is not None:
            # Basic IPv4/IPv6 validation
            import ipaddress

            try:
                ipaddress.ip_address(v)
            except ValueError:
                # Replace ipaddress's own message; the chained context adds nothing.
                raise ValueError(f"Invalid IP address format: {v}") from None
        return v


# Convenience functions for common validation patterns


def validate_query(query: str, **kwargs) -> QueryInput:
    """
    Build a :class:`QueryInput` from a raw query string.

    Args:
        query: Unvalidated query text.
        **kwargs: Any other ``QueryInput`` fields (``use_rag``, ``use_mcts``,
            ``thread_id``).

    Returns:
        QueryInput: The validated, sanitized query model.

    Raises:
        ValidationError: If any field fails validation.
    """
    fields: dict[str, Any] = {"query": query}
    fields.update(kwargs)
    return QueryInput(**fields)


def validate_mcts_config(**kwargs) -> MCTSConfig:
    """
    Build an :class:`MCTSConfig` from keyword arguments.

    Args:
        **kwargs: ``MCTSConfig`` fields; any omitted field uses its default.

    Returns:
        MCTSConfig: The validated configuration.

    Raises:
        ValidationError: If any field fails validation or an unknown key is given.
    """
    config = MCTSConfig(**kwargs)
    return config


def validate_tool_input(tool_name: str, parameters: dict[str, Any], **kwargs) -> MCPToolInput:
    """
    Build an :class:`MCPToolInput` for a tool invocation.

    Args:
        tool_name: Name of the MCP tool.
        parameters: Keyword arguments that will be passed to the tool.
        **kwargs: Any other ``MCPToolInput`` fields (e.g. ``timeout_seconds``).

    Returns:
        MCPToolInput: The validated tool input.

    Raises:
        ValidationError: If any field fails validation.
    """
    payload: dict[str, Any] = {"tool_name": tool_name, "parameters": parameters}
    payload.update(kwargs)
    return MCPToolInput(**payload)


# Type exports: the public API of this module. Kept as a literal list so that
# static analyzers and `from ... import *` see exactly these names.
__all__ = [
    # Request / configuration models
    "QueryInput",
    "MCTSConfig",
    "AgentConfig",
    "RAGConfig",
    # MCP tool input models
    "MCPToolInput",
    "FileReadInput",
    "WebFetchInput",
    # Batch and audit models
    "BatchQueryInput",
    "APIRequestMetadata",
    # Convenience constructors
    "validate_query",
    "validate_mcts_config",
    "validate_tool_input",
]